faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
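The single largest change in this range is the rewrite of the scalar quantizer (data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp, +1104/-147, excerpted below), which adds AVX-512 and ARM NEON code paths and two new quantizer types, QT_bf16 and QT_8bit_direct_signed. As a rough illustration of what the new bf16 codec stores (a minimal sketch only, not the vendored implementation; it assumes plain truncation on encode, whereas the library may round to nearest), a bfloat16 value keeps just the upper 16 bits of the float32 bit pattern:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Hypothetical helpers sketching the bf16 idea: keep the sign, exponent and
    // top 7 mantissa bits of a float32; decoding restores the high half and
    // zero-fills the rest (the diff below decodes with a 16-bit left shift).
    static uint16_t bf16_encode(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));
        return static_cast<uint16_t>(bits >> 16); // truncation; the real codec may round
    }

    static float bf16_decode(uint16_t code) {
        uint32_t bits = static_cast<uint32_t>(code) << 16;
        float x;
        std::memcpy(&x, &bits, sizeof(x));
        return x;
    }

    int main() {
        float v = 0.15625f;
        uint16_t c = bf16_encode(v);
        std::printf("%.6f -> 0x%04x -> %.6f\n", v, c, bf16_decode(c));
        return 0;
    }

This gives half the storage of float32 at the cost of mantissa precision while preserving the full float32 exponent range, which is what distinguishes it from the existing fp16 (QT_fp16) option.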
data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp

@@ -23,6 +23,7 @@
 #include <faiss/impl/AuxIndexStructures.h>
 #include <faiss/impl/FaissAssert.h>
 #include <faiss/impl/IDSelector.h>
+#include <faiss/utils/bf16.h>
 #include <faiss/utils/fp16.h>
 #include <faiss/utils/utils.h>
 
@@ -43,7 +44,9 @@ namespace faiss {
  * that hides the template mess.
  ********************************************************************/
 
-#ifdef __AVX2__
+#if defined(__AVX512F__) && defined(__F16C__)
+#define USE_AVX512_F16C
+#elif defined(__AVX2__)
 #ifdef __F16C__
 #define USE_F16C
 #else
@@ -52,6 +55,15 @@ namespace faiss {
 #endif
 #endif
 
+#if defined(__aarch64__)
+#if defined(__GNUC__) && __GNUC__ < 8
+#warning \
+        "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8"
+#else
+#define USE_NEON
+#endif
+#endif
+
 namespace {
 
 typedef ScalarQuantizer::QuantizerType QuantizerType;
@@ -65,42 +77,93 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
  */
 
 struct Codec8bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         code[i] = (int)(255 * x);
     }
 
-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         return (code[i] + 0.5f) / 255.0f;
     }
 
-#ifdef __AVX2__
-    static __m256 decode_8_components(const uint8_t* code, int i) {
-        uint64_t c8 = *(uint64_t*)(code + i);
-        __m128i c4lo = _mm_cvtepu8_epi32(_mm_set1_epi32(c8));
-        __m128i c4hi = _mm_cvtepu8_epi32(_mm_set1_epi32(c8 >> 32));
-        // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
-        __m256i i8 = _mm256_castsi128_si256(c4lo);
-        i8 = _mm256_insertf128_si256(i8, c4hi, 1);
-        __m256 f8 = _mm256_cvtepi32_ps(i8);
-        __m256 half = _mm256_set1_ps(0.5f);
-        f8 = _mm256_add_ps(f8, half);
-        __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
-        return _mm256_mul_ps(f8, one_255);
+#if defined(__AVX512F__)
+    static FAISS_ALWAYS_INLINE __m512
+    decode_16_components(const uint8_t* code, int i) {
+        const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
+        const __m512i i32 = _mm512_cvtepu8_epi32(c16);
+        const __m512 f16 = _mm512_cvtepi32_ps(i32);
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
+        return _mm512_fmadd_ps(f16, one_255, half_one_255);
+    }
+#elif defined(__AVX2__)
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
+        const uint64_t c8 = *(uint64_t*)(code + i);
+
+        const __m128i i8 = _mm_set1_epi64x(c8);
+        const __m256i i32 = _mm256_cvtepu8_epi32(i8);
+        const __m256 f8 = _mm256_cvtepi32_ps(i32);
+        const __m256 half_one_255 = _mm256_set1_ps(0.5f / 255.f);
+        const __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
+        return _mm256_fmadd_ps(f8, one_255, half_one_255);
+    }
+#endif
+
+#ifdef USE_NEON
+    static FAISS_ALWAYS_INLINE float32x4x2_t
+    decode_8_components(const uint8_t* code, int i) {
+        float32_t result[8] = {};
+        for (size_t j = 0; j < 8; j++) {
+            result[j] = decode_component(code, i + j);
+        }
+        float32x4_t res1 = vld1q_f32(result);
+        float32x4_t res2 = vld1q_f32(result + 4);
+        return {res1, res2};
     }
 #endif
 };
 
 struct Codec4bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
     }
 
-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
     }
 
-#ifdef __AVX2__
-    static __m256 decode_8_components(const uint8_t* code, int i) {
+#if defined(__AVX512F__)
+    static FAISS_ALWAYS_INLINE __m512
+    decode_16_components(const uint8_t* code, int i) {
+        uint64_t c8 = *(uint64_t*)(code + (i >> 1));
+        uint64_t mask = 0x0f0f0f0f0f0f0f0f;
+        uint64_t c8ev = c8 & mask;
+        uint64_t c8od = (c8 >> 4) & mask;
+
+        __m128i c16 =
+                _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
+        __m256i c8lo = _mm256_cvtepu8_epi32(c16);
+        __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
+        __m512i i16 = _mm512_castsi256_si512(c8lo);
+        i16 = _mm512_inserti32x8(i16, c8hi, 1);
+        __m512 f16 = _mm512_cvtepi32_ps(i16);
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
+        return _mm512_fmadd_ps(f16, one_255, half_one_255);
+    }
+#elif defined(__AVX2__)
+    static FAISS_ALWAYS_INLINE __m256
+    decode_8_components(const uint8_t* code, int i) {
         uint32_t c4 = *(uint32_t*)(code + (i >> 1));
         uint32_t mask = 0x0f0f0f0f;
         uint32_t c4ev = c4 & mask;
@@ -120,10 +183,26 @@ struct Codec4bit {
         return _mm256_mul_ps(f8, one_255);
     }
 #endif
+
+#ifdef USE_NEON
+    static FAISS_ALWAYS_INLINE float32x4x2_t
+    decode_8_components(const uint8_t* code, int i) {
+        float32_t result[8] = {};
+        for (size_t j = 0; j < 8; j++) {
+            result[j] = decode_component(code, i + j);
+        }
+        float32x4_t res1 = vld1q_f32(result);
+        float32x4_t res2 = vld1q_f32(result + 4);
+        return {res1, res2};
+    }
+#endif
 };
 
 struct Codec6bit {
-    static void encode_component(float x, uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE void encode_component(
+            float x,
+            uint8_t* code,
+            int i) {
         int bits = (int)(x * 63.0);
         code += (i >> 2) * 3;
         switch (i & 3) {
@@ -144,7 +223,9 @@ struct Codec6bit {
         }
     }
 
-    static float decode_component(const uint8_t* code, int i) {
+    static FAISS_ALWAYS_INLINE float decode_component(
+            const uint8_t* code,
+            int i) {
         uint8_t bits;
         code += (i >> 2) * 3;
         switch (i & 3) {
@@ -166,11 +247,60 @@ struct Codec6bit {
166
247
  return (bits + 0.5f) / 63.0f;
167
248
  }
168
249
 
169
- #ifdef __AVX2__
250
+ #if defined(__AVX512F__)
251
+
252
+ static FAISS_ALWAYS_INLINE __m512
253
+ decode_16_components(const uint8_t* code, int i) {
254
+ // pure AVX512 implementation (not necessarily the fastest).
255
+ // see:
256
+ // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
257
+
258
+ // clang-format off
259
+
260
+ // 16 components, 16x6 bit=12 bytes
261
+ const __m128i bit_6v =
262
+ _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
263
+ const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
264
+
265
+ // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
266
+ // 00 01 02 03
267
+ const __m256i shuffle_mask = _mm256_setr_epi16(
268
+ 0xFF00, 0x0100, 0x0201, 0xFF02,
269
+ 0xFF03, 0x0403, 0x0504, 0xFF05,
270
+ 0xFF06, 0x0706, 0x0807, 0xFF08,
271
+ 0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
272
+ const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
273
+
274
+ // 0: xxxxxxxx xx543210
275
+ // 1: xxxx5432 10xxxxxx
276
+ // 2: xxxxxx54 3210xxxx
277
+ // 3: xxxxxxxx 543210xx
278
+ const __m256i shift_right_v = _mm256_setr_epi16(
279
+ 0x0U, 0x6U, 0x4U, 0x2U,
280
+ 0x0U, 0x6U, 0x4U, 0x2U,
281
+ 0x0U, 0x6U, 0x4U, 0x2U,
282
+ 0x0U, 0x6U, 0x4U, 0x2U);
283
+ __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
284
+
285
+ // remove unneeded bits
286
+ shuffled_shifted =
287
+ _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
288
+
289
+ // scale
290
+ const __m512 f8 =
291
+ _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
292
+ const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
293
+ const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
294
+ return _mm512_fmadd_ps(f8, one_255, half_one_255);
295
+
296
+ // clang-format on
297
+ }
298
+
299
+ #elif defined(__AVX2__)
170
300
 
171
301
  /* Load 6 bytes that represent 8 6-bit values, return them as a
172
302
  * 8*32 bit vector register */
173
- static __m256i load6(const uint16_t* code16) {
303
+ static FAISS_ALWAYS_INLINE __m256i load6(const uint16_t* code16) {
174
304
  const __m128i perm = _mm_set_epi8(
175
305
  -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
176
306
  const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
@@ -189,18 +319,44 @@ struct Codec6bit {
189
319
  return c5;
190
320
  }
191
321
 
192
- static __m256 decode_8_components(const uint8_t* code, int i) {
322
+ static FAISS_ALWAYS_INLINE __m256
323
+ decode_8_components(const uint8_t* code, int i) {
324
+ // // Faster code for Intel CPUs or AMD Zen3+, just keeping it here
325
+ // // for the reference, maybe, it becomes used oned day.
326
+ // const uint16_t* data16 = (const uint16_t*)(code + (i >> 2) * 3);
327
+ // const uint32_t* data32 = (const uint32_t*)data16;
328
+ // const uint64_t val = *data32 + ((uint64_t)data16[2] << 32);
329
+ // const uint64_t vext = _pdep_u64(val, 0x3F3F3F3F3F3F3F3FULL);
330
+ // const __m128i i8 = _mm_set1_epi64x(vext);
331
+ // const __m256i i32 = _mm256_cvtepi8_epi32(i8);
332
+ // const __m256 f8 = _mm256_cvtepi32_ps(i32);
333
+ // const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
334
+ // const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
335
+ // return _mm256_fmadd_ps(f8, one_255, half_one_255);
336
+
193
337
  __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
194
338
  __m256 f8 = _mm256_cvtepi32_ps(i8);
195
339
  // this could also be done with bit manipulations but it is
196
340
  // not obviously faster
197
- __m256 half = _mm256_set1_ps(0.5f);
198
- f8 = _mm256_add_ps(f8, half);
199
- __m256 one_63 = _mm256_set1_ps(1.f / 63.f);
200
- return _mm256_mul_ps(f8, one_63);
341
+ const __m256 half_one_255 = _mm256_set1_ps(0.5f / 63.f);
342
+ const __m256 one_255 = _mm256_set1_ps(1.f / 63.f);
343
+ return _mm256_fmadd_ps(f8, one_255, half_one_255);
201
344
  }
202
345
 
203
346
  #endif
347
+
348
+ #ifdef USE_NEON
349
+ static FAISS_ALWAYS_INLINE float32x4x2_t
350
+ decode_8_components(const uint8_t* code, int i) {
351
+ float32_t result[8] = {};
352
+ for (size_t j = 0; j < 8; j++) {
353
+ result[j] = decode_component(code, i + j);
354
+ }
355
+ float32x4_t res1 = vld1q_f32(result);
356
+ float32x4_t res2 = vld1q_f32(result + 4);
357
+ return {res1, res2};
358
+ }
359
+ #endif
204
360
  };
205
361
 
206
362
  /*******************************************************************
@@ -208,11 +364,14 @@ struct Codec6bit {
208
364
  * through a codec
209
365
  *******************************************************************/
210
366
 
211
- template <class Codec, bool uniform, int SIMD>
367
+ enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 };
368
+
369
+ template <class Codec, QuantizerTemplateScaling SCALING, int SIMD>
212
370
  struct QuantizerTemplate {};
213
371
 
214
372
  template <class Codec>
215
- struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
373
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>
374
+ : ScalarQuantizer::SQuantizer {
216
375
  const size_t d;
217
376
  const float vmin, vdiff;
218
377
 
@@ -242,31 +401,80 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
242
401
  }
243
402
  }
244
403
 
245
- float reconstruct_component(const uint8_t* code, int i) const {
404
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
405
+ const {
246
406
  float xi = Codec::decode_component(code, i);
247
407
  return vmin + xi * vdiff;
248
408
  }
249
409
  };
250
410
 
251
- #ifdef __AVX2__
411
+ #if defined(__AVX512F__)
412
+
413
+ template <class Codec>
414
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 16>
415
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
416
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
417
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
418
+ d,
419
+ trained) {}
420
+
421
+ FAISS_ALWAYS_INLINE __m512
422
+ reconstruct_16_components(const uint8_t* code, int i) const {
423
+ __m512 xi = Codec::decode_16_components(code, i);
424
+ return _mm512_fmadd_ps(
425
+ xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin));
426
+ }
427
+ };
428
+
429
+ #elif defined(__AVX2__)
252
430
 
253
431
  template <class Codec>
254
- struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
432
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
433
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
255
434
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
256
- : QuantizerTemplate<Codec, true, 1>(d, trained) {}
435
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
436
+ d,
437
+ trained) {}
257
438
 
258
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
439
+ FAISS_ALWAYS_INLINE __m256
440
+ reconstruct_8_components(const uint8_t* code, int i) const {
259
441
  __m256 xi = Codec::decode_8_components(code, i);
260
- return _mm256_add_ps(
261
- _mm256_set1_ps(this->vmin),
262
- _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
442
+ return _mm256_fmadd_ps(
443
+ xi, _mm256_set1_ps(this->vdiff), _mm256_set1_ps(this->vmin));
444
+ }
445
+ };
446
+
447
+ #endif
448
+
449
+ #ifdef USE_NEON
450
+
451
+ template <class Codec>
452
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
453
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
454
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
455
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
456
+ d,
457
+ trained) {}
458
+
459
+ FAISS_ALWAYS_INLINE float32x4x2_t
460
+ reconstruct_8_components(const uint8_t* code, int i) const {
461
+ float32x4x2_t xi = Codec::decode_8_components(code, i);
462
+ return {vfmaq_f32(
463
+ vdupq_n_f32(this->vmin),
464
+ xi.val[0],
465
+ vdupq_n_f32(this->vdiff)),
466
+ vfmaq_f32(
467
+ vdupq_n_f32(this->vmin),
468
+ xi.val[1],
469
+ vdupq_n_f32(this->vdiff))};
263
470
  }
264
471
  };
265
472
 
266
473
  #endif
267
474
 
268
475
  template <class Codec>
269
- struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
476
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>
477
+ : ScalarQuantizer::SQuantizer {
270
478
  const size_t d;
271
479
  const float *vmin, *vdiff;
272
480
 
@@ -296,24 +504,77 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
296
504
  }
297
505
  }
298
506
 
299
- float reconstruct_component(const uint8_t* code, int i) const {
507
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
508
+ const {
300
509
  float xi = Codec::decode_component(code, i);
301
510
  return vmin[i] + xi * vdiff[i];
302
511
  }
303
512
  };
304
513
 
305
- #ifdef __AVX2__
514
+ #if defined(__AVX512F__)
515
+
516
+ template <class Codec>
517
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 16>
518
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
519
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
520
+ : QuantizerTemplate<
521
+ Codec,
522
+ QuantizerTemplateScaling::NON_UNIFORM,
523
+ 1>(d, trained) {}
524
+
525
+ FAISS_ALWAYS_INLINE __m512
526
+ reconstruct_16_components(const uint8_t* code, int i) const {
527
+ __m512 xi = Codec::decode_16_components(code, i);
528
+ return _mm512_fmadd_ps(
529
+ xi,
530
+ _mm512_loadu_ps(this->vdiff + i),
531
+ _mm512_loadu_ps(this->vmin + i));
532
+ }
533
+ };
534
+
535
+ #elif defined(__AVX2__)
306
536
 
307
537
  template <class Codec>
308
- struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
538
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
539
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
309
540
  QuantizerTemplate(size_t d, const std::vector<float>& trained)
310
- : QuantizerTemplate<Codec, false, 1>(d, trained) {}
541
+ : QuantizerTemplate<
542
+ Codec,
543
+ QuantizerTemplateScaling::NON_UNIFORM,
544
+ 1>(d, trained) {}
311
545
 
312
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
546
+ FAISS_ALWAYS_INLINE __m256
547
+ reconstruct_8_components(const uint8_t* code, int i) const {
313
548
  __m256 xi = Codec::decode_8_components(code, i);
314
- return _mm256_add_ps(
315
- _mm256_loadu_ps(this->vmin + i),
316
- _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
549
+ return _mm256_fmadd_ps(
550
+ xi,
551
+ _mm256_loadu_ps(this->vdiff + i),
552
+ _mm256_loadu_ps(this->vmin + i));
553
+ }
554
+ };
555
+
556
+ #endif
557
+
558
+ #ifdef USE_NEON
559
+
560
+ template <class Codec>
561
+ struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
562
+ : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
563
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
564
+ : QuantizerTemplate<
565
+ Codec,
566
+ QuantizerTemplateScaling::NON_UNIFORM,
567
+ 1>(d, trained) {}
568
+
569
+ FAISS_ALWAYS_INLINE float32x4x2_t
570
+ reconstruct_8_components(const uint8_t* code, int i) const {
571
+ float32x4x2_t xi = Codec::decode_8_components(code, i);
572
+
573
+ float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
574
+ float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);
575
+
576
+ return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
577
+ vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
317
578
  }
318
579
  };
319
580
 
@@ -344,19 +605,37 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
344
605
  }
345
606
  }
346
607
 
347
- float reconstruct_component(const uint8_t* code, int i) const {
608
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
609
+ const {
348
610
  return decode_fp16(((uint16_t*)code)[i]);
349
611
  }
350
612
  };
351
613
 
352
- #ifdef USE_F16C
614
+ #if defined(USE_AVX512_F16C)
615
+
616
+ template <>
617
+ struct QuantizerFP16<16> : QuantizerFP16<1> {
618
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
619
+ : QuantizerFP16<1>(d, trained) {}
620
+
621
+ FAISS_ALWAYS_INLINE __m512
622
+ reconstruct_16_components(const uint8_t* code, int i) const {
623
+ __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
624
+ return _mm512_cvtph_ps(codei);
625
+ }
626
+ };
627
+
628
+ #endif
629
+
630
+ #if defined(USE_F16C)
353
631
 
354
632
  template <>
355
633
  struct QuantizerFP16<8> : QuantizerFP16<1> {
356
634
  QuantizerFP16(size_t d, const std::vector<float>& trained)
357
635
  : QuantizerFP16<1>(d, trained) {}
358
636
 
359
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
637
+ FAISS_ALWAYS_INLINE __m256
638
+ reconstruct_8_components(const uint8_t* code, int i) const {
360
639
  __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
361
640
  return _mm256_cvtph_ps(codei);
362
641
  }
@@ -364,6 +643,103 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {
364
643
 
365
644
  #endif
366
645
 
646
+ #ifdef USE_NEON
647
+
648
+ template <>
649
+ struct QuantizerFP16<8> : QuantizerFP16<1> {
650
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
651
+ : QuantizerFP16<1>(d, trained) {}
652
+
653
+ FAISS_ALWAYS_INLINE float32x4x2_t
654
+ reconstruct_8_components(const uint8_t* code, int i) const {
655
+ uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
656
+ return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
657
+ vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))};
658
+ }
659
+ };
660
+ #endif
661
+
662
+ /*******************************************************************
663
+ * BF16 quantizer
664
+ *******************************************************************/
665
+
666
+ template <int SIMDWIDTH>
667
+ struct QuantizerBF16 {};
668
+
669
+ template <>
670
+ struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
671
+ const size_t d;
672
+
673
+ QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
674
+
675
+ void encode_vector(const float* x, uint8_t* code) const final {
676
+ for (size_t i = 0; i < d; i++) {
677
+ ((uint16_t*)code)[i] = encode_bf16(x[i]);
678
+ }
679
+ }
680
+
681
+ void decode_vector(const uint8_t* code, float* x) const final {
682
+ for (size_t i = 0; i < d; i++) {
683
+ x[i] = decode_bf16(((uint16_t*)code)[i]);
684
+ }
685
+ }
686
+
687
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
688
+ const {
689
+ return decode_bf16(((uint16_t*)code)[i]);
690
+ }
691
+ };
692
+
693
+ #if defined(__AVX512F__)
694
+
695
+ template <>
696
+ struct QuantizerBF16<16> : QuantizerBF16<1> {
697
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
698
+ : QuantizerBF16<1>(d, trained) {}
699
+ FAISS_ALWAYS_INLINE __m512
700
+ reconstruct_16_components(const uint8_t* code, int i) const {
701
+ __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
702
+ __m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
703
+ code_512i = _mm512_slli_epi32(code_512i, 16);
704
+ return _mm512_castsi512_ps(code_512i);
705
+ }
706
+ };
707
+
708
+ #elif defined(__AVX2__)
709
+
710
+ template <>
711
+ struct QuantizerBF16<8> : QuantizerBF16<1> {
712
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
713
+ : QuantizerBF16<1>(d, trained) {}
714
+
715
+ FAISS_ALWAYS_INLINE __m256
716
+ reconstruct_8_components(const uint8_t* code, int i) const {
717
+ __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i));
718
+ __m256i code_256i = _mm256_cvtepu16_epi32(code_128i);
719
+ code_256i = _mm256_slli_epi32(code_256i, 16);
720
+ return _mm256_castsi256_ps(code_256i);
721
+ }
722
+ };
723
+
724
+ #endif
725
+
726
+ #ifdef USE_NEON
727
+
728
+ template <>
729
+ struct QuantizerBF16<8> : QuantizerBF16<1> {
730
+ QuantizerBF16(size_t d, const std::vector<float>& trained)
731
+ : QuantizerBF16<1>(d, trained) {}
732
+
733
+ FAISS_ALWAYS_INLINE float32x4x2_t
734
+ reconstruct_8_components(const uint8_t* code, int i) const {
735
+ uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
736
+ return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
737
+ vreinterpretq_f32_u32(
738
+ vshlq_n_u32(vmovl_u16(codei.val[1]), 16))};
739
+ }
740
+ };
741
+ #endif
742
+
367
743
  /*******************************************************************
368
744
  * 8bit_direct quantizer
369
745
  *******************************************************************/
@@ -390,19 +766,36 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
390
766
  }
391
767
  }
392
768
 
393
- float reconstruct_component(const uint8_t* code, int i) const {
769
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
770
+ const {
394
771
  return code[i];
395
772
  }
396
773
  };
397
774
 
398
- #ifdef __AVX2__
775
+ #if defined(__AVX512F__)
776
+
777
+ template <>
778
+ struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> {
779
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
780
+ : Quantizer8bitDirect<1>(d, trained) {}
781
+
782
+ FAISS_ALWAYS_INLINE __m512
783
+ reconstruct_16_components(const uint8_t* code, int i) const {
784
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
785
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
786
+ return _mm512_cvtepi32_ps(y16); // 16 * float32
787
+ }
788
+ };
789
+
790
+ #elif defined(__AVX2__)
399
791
 
400
792
  template <>
401
793
  struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
402
794
  Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
403
795
  : Quantizer8bitDirect<1>(d, trained) {}
404
796
 
405
- __m256 reconstruct_8_components(const uint8_t* code, int i) const {
797
+ FAISS_ALWAYS_INLINE __m256
798
+ reconstruct_8_components(const uint8_t* code, int i) const {
406
799
  __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
407
800
  __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
408
801
  return _mm256_cvtepi32_ps(y8); // 8 * float32
@@ -411,6 +804,121 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
411
804
 
412
805
  #endif
413
806
 
807
+ #ifdef USE_NEON
808
+
809
+ template <>
810
+ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
811
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
812
+ : Quantizer8bitDirect<1>(d, trained) {}
813
+
814
+ FAISS_ALWAYS_INLINE float32x4x2_t
815
+ reconstruct_8_components(const uint8_t* code, int i) const {
816
+ uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
817
+ uint16x8_t y8 = vmovl_u8(x8);
818
+ uint16x4_t y8_0 = vget_low_u16(y8);
819
+ uint16x4_t y8_1 = vget_high_u16(y8);
820
+
821
+ // convert uint16 -> uint32 -> fp32
822
+ return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))};
823
+ }
824
+ };
825
+
826
+ #endif
827
+
828
+ /*******************************************************************
829
+ * 8bit_direct_signed quantizer
830
+ *******************************************************************/
831
+
832
+ template <int SIMDWIDTH>
833
+ struct Quantizer8bitDirectSigned {};
834
+
835
+ template <>
836
+ struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer {
837
+ const size_t d;
838
+
839
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& /* unused */)
840
+ : d(d) {}
841
+
842
+ void encode_vector(const float* x, uint8_t* code) const final {
843
+ for (size_t i = 0; i < d; i++) {
844
+ code[i] = (uint8_t)(x[i] + 128);
845
+ }
846
+ }
847
+
848
+ void decode_vector(const uint8_t* code, float* x) const final {
849
+ for (size_t i = 0; i < d; i++) {
850
+ x[i] = code[i] - 128;
851
+ }
852
+ }
853
+
854
+ FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
855
+ const {
856
+ return code[i] - 128;
857
+ }
858
+ };
859
+
860
+ #if defined(__AVX512F__)
861
+
862
+ template <>
863
+ struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> {
864
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
865
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
866
+
867
+ FAISS_ALWAYS_INLINE __m512
868
+ reconstruct_16_components(const uint8_t* code, int i) const {
869
+ __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
870
+ __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
871
+ __m512i c16 = _mm512_set1_epi32(128);
872
+ __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
873
+ return _mm512_cvtepi32_ps(z16); // 16 * float32
874
+ }
875
+ };
876
+
877
+ #elif defined(__AVX2__)
878
+
879
+ template <>
880
+ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
881
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
882
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
883
+
884
+ FAISS_ALWAYS_INLINE __m256
885
+ reconstruct_8_components(const uint8_t* code, int i) const {
886
+ __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
887
+ __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
888
+ __m256i c8 = _mm256_set1_epi32(128);
889
+ __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes
890
+ return _mm256_cvtepi32_ps(z8); // 8 * float32
891
+ }
892
+ };
893
+
894
+ #endif
895
+
896
+ #ifdef USE_NEON
897
+
898
+ template <>
899
+ struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
900
+ Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
901
+ : Quantizer8bitDirectSigned<1>(d, trained) {}
902
+
903
+ FAISS_ALWAYS_INLINE float32x4x2_t
904
+ reconstruct_8_components(const uint8_t* code, int i) const {
905
+ uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
906
+ uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16
907
+ uint16x4_t y8_0 = vget_low_u16(y8);
908
+ uint16x4_t y8_1 = vget_high_u16(y8);
909
+
910
+ float32x4_t z8_0 = vcvtq_f32_u32(
911
+ vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32
912
+ float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1));
913
+
914
+ // subtract 128 to convert into signed numbers
915
+ return {vsubq_f32(z8_0, vmovq_n_f32(128.0)),
916
+ vsubq_f32(z8_1, vmovq_n_f32(128.0))};
917
+ }
918
+ };
919
+
920
+ #endif
921
+
414
922
  template <int SIMDWIDTH>
415
923
  ScalarQuantizer::SQuantizer* select_quantizer_1(
416
924
  QuantizerType qtype,
@@ -418,24 +926,38 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
418
926
  const std::vector<float>& trained) {
419
927
  switch (qtype) {
420
928
  case ScalarQuantizer::QT_8bit:
421
- return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
422
- d, trained);
929
+ return new QuantizerTemplate<
930
+ Codec8bit,
931
+ QuantizerTemplateScaling::NON_UNIFORM,
932
+ SIMDWIDTH>(d, trained);
423
933
  case ScalarQuantizer::QT_6bit:
424
- return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
425
- d, trained);
934
+ return new QuantizerTemplate<
935
+ Codec6bit,
936
+ QuantizerTemplateScaling::NON_UNIFORM,
937
+ SIMDWIDTH>(d, trained);
426
938
  case ScalarQuantizer::QT_4bit:
427
- return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
428
- d, trained);
939
+ return new QuantizerTemplate<
940
+ Codec4bit,
941
+ QuantizerTemplateScaling::NON_UNIFORM,
942
+ SIMDWIDTH>(d, trained);
429
943
  case ScalarQuantizer::QT_8bit_uniform:
430
- return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
431
- d, trained);
944
+ return new QuantizerTemplate<
945
+ Codec8bit,
946
+ QuantizerTemplateScaling::UNIFORM,
947
+ SIMDWIDTH>(d, trained);
432
948
  case ScalarQuantizer::QT_4bit_uniform:
433
- return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
434
- d, trained);
949
+ return new QuantizerTemplate<
950
+ Codec4bit,
951
+ QuantizerTemplateScaling::UNIFORM,
952
+ SIMDWIDTH>(d, trained);
435
953
  case ScalarQuantizer::QT_fp16:
436
954
  return new QuantizerFP16<SIMDWIDTH>(d, trained);
955
+ case ScalarQuantizer::QT_bf16:
956
+ return new QuantizerBF16<SIMDWIDTH>(d, trained);
437
957
  case ScalarQuantizer::QT_8bit_direct:
438
958
  return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
959
+ case ScalarQuantizer::QT_8bit_direct_signed:
960
+ return new Quantizer8bitDirectSigned<SIMDWIDTH>(d, trained);
439
961
  }
440
962
  FAISS_THROW_MSG("unknown qtype");
441
963
  }
@@ -486,7 +1008,7 @@ void train_Uniform(
     } else if (rs == ScalarQuantizer::RS_quantiles) {
        std::vector<float> x_copy(n);
        memcpy(x_copy.data(), x, n * sizeof(*x));
-        // TODO just do a qucikselect
+        // TODO just do a quickselect
        std::sort(x_copy.begin(), x_copy.end());
        int o = int(rs_arg * n);
        if (o < 0)
@@ -632,27 +1154,63 @@ struct SimilarityL2<1> {
632
1154
 
633
1155
  float accu;
634
1156
 
635
- void begin() {
1157
+ FAISS_ALWAYS_INLINE void begin() {
636
1158
  accu = 0;
637
1159
  yi = y;
638
1160
  }
639
1161
 
640
- void add_component(float x) {
1162
+ FAISS_ALWAYS_INLINE void add_component(float x) {
641
1163
  float tmp = *yi++ - x;
642
1164
  accu += tmp * tmp;
643
1165
  }
644
1166
 
645
- void add_component_2(float x1, float x2) {
1167
+ FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
646
1168
  float tmp = x1 - x2;
647
1169
  accu += tmp * tmp;
648
1170
  }
649
1171
 
650
- float result() {
1172
+ FAISS_ALWAYS_INLINE float result() {
651
1173
  return accu;
652
1174
  }
653
1175
  };
654
1176
 
655
- #ifdef __AVX2__
1177
+ #if defined(__AVX512F__)
1178
+
1179
+ template <>
1180
+ struct SimilarityL2<16> {
1181
+ static constexpr int simdwidth = 16;
1182
+ static constexpr MetricType metric_type = METRIC_L2;
1183
+
1184
+ const float *y, *yi;
1185
+
1186
+ explicit SimilarityL2(const float* y) : y(y) {}
1187
+ __m512 accu16;
1188
+
1189
+ FAISS_ALWAYS_INLINE void begin_16() {
1190
+ accu16 = _mm512_setzero_ps();
1191
+ yi = y;
1192
+ }
1193
+
1194
+ FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
1195
+ __m512 yiv = _mm512_loadu_ps(yi);
1196
+ yi += 16;
1197
+ __m512 tmp = _mm512_sub_ps(yiv, x);
1198
+ accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
1199
+ }
1200
+
1201
+ FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) {
1202
+ __m512 tmp = _mm512_sub_ps(y_2, x);
1203
+ accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
1204
+ }
1205
+
1206
+ FAISS_ALWAYS_INLINE float result_16() {
1207
+ // performs better than dividing into _mm256 and adding
1208
+ return _mm512_reduce_add_ps(accu16);
1209
+ }
1210
+ };
1211
+
1212
+ #elif defined(__AVX2__)
1213
+
656
1214
  template <>
657
1215
  struct SimilarityL2<8> {
658
1216
  static constexpr int simdwidth = 8;
@@ -663,34 +1221,87 @@ struct SimilarityL2<8> {
663
1221
  explicit SimilarityL2(const float* y) : y(y) {}
664
1222
  __m256 accu8;
665
1223
 
666
- void begin_8() {
1224
+ FAISS_ALWAYS_INLINE void begin_8() {
667
1225
  accu8 = _mm256_setzero_ps();
668
1226
  yi = y;
669
1227
  }
670
1228
 
671
- void add_8_components(__m256 x) {
1229
+ FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
672
1230
  __m256 yiv = _mm256_loadu_ps(yi);
673
1231
  yi += 8;
674
1232
  __m256 tmp = _mm256_sub_ps(yiv, x);
675
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
1233
+ accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
676
1234
  }
677
1235
 
678
- void add_8_components_2(__m256 x, __m256 y) {
679
- __m256 tmp = _mm256_sub_ps(y, x);
680
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
1236
+ FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x, __m256 y_2) {
1237
+ __m256 tmp = _mm256_sub_ps(y_2, x);
1238
+ accu8 = _mm256_fmadd_ps(tmp, tmp, accu8);
681
1239
  }
682
1240
 
683
- float result_8() {
684
- __m256 sum = _mm256_hadd_ps(accu8, accu8);
685
- __m256 sum2 = _mm256_hadd_ps(sum, sum);
686
- // now add the 0th and 4th component
687
- return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
688
- _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
1241
+ FAISS_ALWAYS_INLINE float result_8() {
1242
+ const __m128 sum = _mm_add_ps(
1243
+ _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
1244
+ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
1245
+ const __m128 v1 = _mm_add_ps(sum, v0);
1246
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
1247
+ const __m128 v3 = _mm_add_ps(v1, v2);
1248
+ return _mm_cvtss_f32(v3);
689
1249
  }
690
1250
  };
691
1251
 
692
1252
  #endif
693
1253
 
1254
+ #ifdef USE_NEON
1255
+ template <>
1256
+ struct SimilarityL2<8> {
1257
+ static constexpr int simdwidth = 8;
1258
+ static constexpr MetricType metric_type = METRIC_L2;
1259
+
1260
+ const float *y, *yi;
1261
+ explicit SimilarityL2(const float* y) : y(y) {}
1262
+ float32x4x2_t accu8;
1263
+
1264
+ FAISS_ALWAYS_INLINE void begin_8() {
1265
+ accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
1266
+ yi = y;
1267
+ }
1268
+
1269
+ FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
1270
+ float32x4x2_t yiv = vld1q_f32_x2(yi);
1271
+ yi += 8;
1272
+
1273
+ float32x4_t sub0 = vsubq_f32(yiv.val[0], x.val[0]);
1274
+ float32x4_t sub1 = vsubq_f32(yiv.val[1], x.val[1]);
1275
+
1276
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
1277
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
1278
+
1279
+ accu8 = {accu8_0, accu8_1};
1280
+ }
1281
+
1282
+ FAISS_ALWAYS_INLINE void add_8_components_2(
1283
+ float32x4x2_t x,
1284
+ float32x4x2_t y) {
1285
+ float32x4_t sub0 = vsubq_f32(y.val[0], x.val[0]);
1286
+ float32x4_t sub1 = vsubq_f32(y.val[1], x.val[1]);
1287
+
1288
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
1289
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);
1290
+
1291
+ accu8 = {accu8_0, accu8_1};
1292
+ }
1293
+
1294
+ FAISS_ALWAYS_INLINE float result_8() {
1295
+ float32x4_t sum_0 = vpaddq_f32(accu8.val[0], accu8.val[0]);
1296
+ float32x4_t sum_1 = vpaddq_f32(accu8.val[1], accu8.val[1]);
1297
+
1298
+ float32x4_t sum2_0 = vpaddq_f32(sum_0, sum_0);
1299
+ float32x4_t sum2_1 = vpaddq_f32(sum_1, sum_1);
1300
+ return vgetq_lane_f32(sum2_0, 0) + vgetq_lane_f32(sum2_1, 0);
1301
+ }
1302
+ };
1303
+ #endif
1304
+
694
1305
  template <int SIMDWIDTH>
695
1306
  struct SimilarityIP {};
696
1307
 
@@ -704,25 +1315,61 @@ struct SimilarityIP<1> {
704
1315
 
705
1316
  explicit SimilarityIP(const float* y) : y(y) {}
706
1317
 
707
- void begin() {
1318
+ FAISS_ALWAYS_INLINE void begin() {
708
1319
  accu = 0;
709
1320
  yi = y;
710
1321
  }
711
1322
 
712
- void add_component(float x) {
1323
+ FAISS_ALWAYS_INLINE void add_component(float x) {
713
1324
  accu += *yi++ * x;
714
1325
  }
715
1326
 
716
- void add_component_2(float x1, float x2) {
1327
+ FAISS_ALWAYS_INLINE void add_component_2(float x1, float x2) {
717
1328
  accu += x1 * x2;
718
1329
  }
719
1330
 
720
- float result() {
1331
+ FAISS_ALWAYS_INLINE float result() {
721
1332
  return accu;
722
1333
  }
723
1334
  };
724
1335
 
725
- #ifdef __AVX2__
1336
+ #if defined(__AVX512F__)
1337
+
1338
+ template <>
1339
+ struct SimilarityIP<16> {
1340
+ static constexpr int simdwidth = 16;
1341
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
1342
+
1343
+ const float *y, *yi;
1344
+
1345
+ float accu;
1346
+
1347
+ explicit SimilarityIP(const float* y) : y(y) {}
1348
+
1349
+ __m512 accu16;
1350
+
1351
+ FAISS_ALWAYS_INLINE void begin_16() {
1352
+ accu16 = _mm512_setzero_ps();
1353
+ yi = y;
1354
+ }
1355
+
1356
+ FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
1357
+ __m512 yiv = _mm512_loadu_ps(yi);
1358
+ yi += 16;
1359
+ accu16 = _mm512_fmadd_ps(yiv, x, accu16);
1360
+ }
1361
+
1362
+ FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) {
1363
+ accu16 = _mm512_fmadd_ps(x1, x2, accu16);
1364
+ }
1365
+
1366
+ FAISS_ALWAYS_INLINE float result_16() {
1367
+ // performs better than dividing into _mm256 and adding
1368
+ return _mm512_reduce_add_ps(accu16);
1369
+ }
1370
+ };
1371
+
1372
+ #elif defined(__AVX2__)
726
1373
 
727
1374
  template <>
728
1375
  struct SimilarityIP<8> {
@@ -737,27 +1384,76 @@ struct SimilarityIP<8> {
737
1384
 
738
1385
  __m256 accu8;
739
1386
 
740
- void begin_8() {
1387
+ FAISS_ALWAYS_INLINE void begin_8() {
741
1388
  accu8 = _mm256_setzero_ps();
742
1389
  yi = y;
743
1390
  }
744
1391
 
745
- void add_8_components(__m256 x) {
1392
+ FAISS_ALWAYS_INLINE void add_8_components(__m256 x) {
746
1393
  __m256 yiv = _mm256_loadu_ps(yi);
747
1394
  yi += 8;
748
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
1395
+ accu8 = _mm256_fmadd_ps(yiv, x, accu8);
1396
+ }
1397
+
1398
+ FAISS_ALWAYS_INLINE void add_8_components_2(__m256 x1, __m256 x2) {
1399
+ accu8 = _mm256_fmadd_ps(x1, x2, accu8);
1400
+ }
1401
+
1402
+ FAISS_ALWAYS_INLINE float result_8() {
1403
+ const __m128 sum = _mm_add_ps(
1404
+ _mm256_castps256_ps128(accu8), _mm256_extractf128_ps(accu8, 1));
1405
+ const __m128 v0 = _mm_shuffle_ps(sum, sum, _MM_SHUFFLE(0, 0, 3, 2));
1406
+ const __m128 v1 = _mm_add_ps(sum, v0);
1407
+ __m128 v2 = _mm_shuffle_ps(v1, v1, _MM_SHUFFLE(0, 0, 0, 1));
1408
+ const __m128 v3 = _mm_add_ps(v1, v2);
1409
+ return _mm_cvtss_f32(v3);
1410
+ }
1411
+ };
1412
+ #endif
1413
+
1414
+ #ifdef USE_NEON
1415
+
1416
+ template <>
1417
+ struct SimilarityIP<8> {
1418
+ static constexpr int simdwidth = 8;
1419
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
1420
+
1421
+ const float *y, *yi;
1422
+
1423
+ explicit SimilarityIP(const float* y) : y(y) {}
1424
+ float32x4x2_t accu8;
1425
+
1426
+ FAISS_ALWAYS_INLINE void begin_8() {
1427
+ accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
1428
+ yi = y;
1429
+ }
1430
+
1431
+ FAISS_ALWAYS_INLINE void add_8_components(float32x4x2_t x) {
1432
+ float32x4x2_t yiv = vld1q_f32_x2(yi);
1433
+ yi += 8;
1434
+
1435
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
1436
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
1437
+ accu8 = {accu8_0, accu8_1};
749
1438
  }
750
1439
 
751
- void add_8_components_2(__m256 x1, __m256 x2) {
752
- accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
1440
+ FAISS_ALWAYS_INLINE void add_8_components_2(
1441
+ float32x4x2_t x1,
1442
+ float32x4x2_t x2) {
1443
+ float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
1444
+ float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
1445
+ accu8 = {accu8_0, accu8_1};
753
1446
  }
754
1447
 
755
- float result_8() {
756
- __m256 sum = _mm256_hadd_ps(accu8, accu8);
757
- __m256 sum2 = _mm256_hadd_ps(sum, sum);
758
- // now add the 0th and 4th component
759
- return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
760
- _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
1448
+ FAISS_ALWAYS_INLINE float result_8() {
1449
+ float32x4x2_t sum = {
1450
+ vpaddq_f32(accu8.val[0], accu8.val[0]),
1451
+ vpaddq_f32(accu8.val[1], accu8.val[1])};
1452
+
1453
+ float32x4x2_t sum2 = {
1454
+ vpaddq_f32(sum.val[0], sum.val[0]),
1455
+ vpaddq_f32(sum.val[1], sum.val[1])};
1456
+ return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
761
1457
  }
762
1458
  };
763
1459
  #endif
@@ -815,7 +1511,55 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
815
1511
  }
816
1512
  };
817
1513
 
818
- #ifdef USE_F16C
1514
+ #if defined(USE_AVX512_F16C)
1515
+
1516
+ template <class Quantizer, class Similarity>
1517
+ struct DCTemplate<Quantizer, Similarity, 16>
1518
+ : SQDistanceComputer { // Update to handle 16 lanes
1519
+ using Sim = Similarity;
1520
+
1521
+ Quantizer quant;
1522
+
1523
+ DCTemplate(size_t d, const std::vector<float>& trained)
1524
+ : quant(d, trained) {}
1525
+
1526
+ float compute_distance(const float* x, const uint8_t* code) const {
1527
+ Similarity sim(x);
1528
+ sim.begin_16();
1529
+ for (size_t i = 0; i < quant.d; i += 16) {
1530
+ __m512 xi = quant.reconstruct_16_components(code, i);
1531
+ sim.add_16_components(xi);
1532
+ }
1533
+ return sim.result_16();
1534
+ }
1535
+
1536
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1537
+ const {
1538
+ Similarity sim(nullptr);
1539
+ sim.begin_16();
1540
+ for (size_t i = 0; i < quant.d; i += 16) {
1541
+ __m512 x1 = quant.reconstruct_16_components(code1, i);
1542
+ __m512 x2 = quant.reconstruct_16_components(code2, i);
1543
+ sim.add_16_components_2(x1, x2);
1544
+ }
1545
+ return sim.result_16();
1546
+ }
1547
+
1548
+ void set_query(const float* x) final {
1549
+ q = x;
1550
+ }
1551
+
1552
+ float symmetric_dis(idx_t i, idx_t j) override {
1553
+ return compute_code_distance(
1554
+ codes + i * code_size, codes + j * code_size);
1555
+ }
1556
+
1557
+ float query_to_code(const uint8_t* code) const final {
1558
+ return compute_distance(q, code);
1559
+ }
1560
+ };
1561
+
1562
+ #elif defined(USE_F16C)
819
1563
 
820
1564
  template <class Quantizer, class Similarity>
821
1565
  struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -864,6 +1608,53 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {

  #endif

+ #ifdef USE_NEON
+
+ template <class Quantizer, class Similarity>
+ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
+ using Sim = Similarity;
+
+ Quantizer quant;
+
+ DCTemplate(size_t d, const std::vector<float>& trained)
+ : quant(d, trained) {}
+ float compute_distance(const float* x, const uint8_t* code) const {
+ Similarity sim(x);
+ sim.begin_8();
+ for (size_t i = 0; i < quant.d; i += 8) {
+ float32x4x2_t xi = quant.reconstruct_8_components(code, i);
+ sim.add_8_components(xi);
+ }
+ return sim.result_8();
+ }
+
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+ const {
+ Similarity sim(nullptr);
+ sim.begin_8();
+ for (size_t i = 0; i < quant.d; i += 8) {
+ float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
+ float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
+ sim.add_8_components_2(x1, x2);
+ }
+ return sim.result_8();
+ }
+
+ void set_query(const float* x) final {
+ q = x;
+ }
+
+ float symmetric_dis(idx_t i, idx_t j) override {
+ return compute_code_distance(
+ codes + i * code_size, codes + j * code_size);
+ }
+
+ float query_to_code(const uint8_t* code) const final {
+ return compute_distance(q, code);
+ }
+ };
+ #endif
+
  /*******************************************************************
  * DistanceComputerByte: computes distances in the integer domain
  *******************************************************************/
@@ -915,7 +1706,60 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
  }
  };

- #ifdef __AVX2__
+ #if defined(__AVX512F__)
+
+ template <class Similarity>
+ struct DistanceComputerByte<Similarity, 16> : SQDistanceComputer {
+ using Sim = Similarity;
+
+ int d;
+ std::vector<uint8_t> tmp;
+
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
+
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+ const {
+ __m512i accu = _mm512_setzero_si512();
+ for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time
+ __m512i c1 = _mm512_cvtepu8_epi16(
+ _mm256_loadu_si256((__m256i*)(code1 + i)));
+ __m512i c2 = _mm512_cvtepu8_epi16(
+ _mm256_loadu_si256((__m256i*)(code2 + i)));
+ __m512i prod32;
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
+ prod32 = _mm512_madd_epi16(c1, c2);
+ } else {
+ __m512i diff = _mm512_sub_epi16(c1, c2);
+ prod32 = _mm512_madd_epi16(diff, diff);
+ }
+ accu = _mm512_add_epi32(accu, prod32);
+ }
+ // Horizontally add elements of accu
+ return _mm512_reduce_add_epi32(accu);
+ }
+
+ void set_query(const float* x) final {
+ for (int i = 0; i < d; i++) {
+ tmp[i] = int(x[i]);
+ }
+ }
+
+ int compute_distance(const float* x, const uint8_t* code) {
+ set_query(x);
+ return compute_code_distance(tmp.data(), code);
+ }
+
+ float symmetric_dis(idx_t i, idx_t j) override {
+ return compute_code_distance(
+ codes + i * code_size, codes + j * code_size);
+ }
+
+ float query_to_code(const uint8_t* code) const final {
+ return compute_code_distance(tmp.data(), code);
+ }
+ };
+
+ #elif defined(__AVX2__)

  template <class Similarity>
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
@@ -980,6 +1824,54 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {

  #endif

+ #ifdef USE_NEON
+
+ template <class Similarity>
+ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
+ using Sim = Similarity;
+
+ int d;
+ std::vector<uint8_t> tmp;
+
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
+
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+ const {
+ int accu = 0;
+ for (int i = 0; i < d; i++) {
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
+ accu += int(code1[i]) * code2[i];
+ } else {
+ int diff = int(code1[i]) - code2[i];
+ accu += diff * diff;
+ }
+ }
+ return accu;
+ }
+
+ void set_query(const float* x) final {
+ for (int i = 0; i < d; i++) {
+ tmp[i] = int(x[i]);
+ }
+ }
+
+ int compute_distance(const float* x, const uint8_t* code) {
+ set_query(x);
+ return compute_code_distance(tmp.data(), code);
+ }
+
+ float symmetric_dis(idx_t i, idx_t j) override {
+ return compute_code_distance(
+ codes + i * code_size, codes + j * code_size);
+ }
+
+ float query_to_code(const uint8_t* code) const final {
+ return compute_code_distance(tmp.data(), code);
+ }
+ };
+
+ #endif
+
  /*******************************************************************
  * select_distance_computer: runtime selection of template
  * specialization
@@ -994,31 +1886,46 @@ SQDistanceComputer* select_distance_computer(
  switch (qtype) {
  case ScalarQuantizer::QT_8bit_uniform:
  return new DCTemplate<
- QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
+ QuantizerTemplate<
+ Codec8bit,
+ QuantizerTemplateScaling::UNIFORM,
+ SIMDWIDTH>,
  Sim,
  SIMDWIDTH>(d, trained);

  case ScalarQuantizer::QT_4bit_uniform:
  return new DCTemplate<
- QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
+ QuantizerTemplate<
+ Codec4bit,
+ QuantizerTemplateScaling::UNIFORM,
+ SIMDWIDTH>,
  Sim,
  SIMDWIDTH>(d, trained);

  case ScalarQuantizer::QT_8bit:
  return new DCTemplate<
- QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
+ QuantizerTemplate<
+ Codec8bit,
+ QuantizerTemplateScaling::NON_UNIFORM,
+ SIMDWIDTH>,
  Sim,
  SIMDWIDTH>(d, trained);

  case ScalarQuantizer::QT_6bit:
  return new DCTemplate<
- QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
+ QuantizerTemplate<
+ Codec6bit,
+ QuantizerTemplateScaling::NON_UNIFORM,
+ SIMDWIDTH>,
  Sim,
  SIMDWIDTH>(d, trained);

  case ScalarQuantizer::QT_4bit:
  return new DCTemplate<
- QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
+ QuantizerTemplate<
+ Codec4bit,
+ QuantizerTemplateScaling::NON_UNIFORM,
+ SIMDWIDTH>,
  Sim,
  SIMDWIDTH>(d, trained);

@@ -1026,15 +1933,31 @@ SQDistanceComputer* select_distance_computer(
  return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
  d, trained);

+ case ScalarQuantizer::QT_bf16:
+ return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
+ d, trained);
+
  case ScalarQuantizer::QT_8bit_direct:
+ #if defined(__AVX512F__)
+ if (d % 32 == 0) {
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
+ } else
+ #elif defined(__AVX2__)
  if (d % 16 == 0) {
  return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
- } else {
+ } else
+ #endif
+ {
  return new DCTemplate<
  Quantizer8bitDirect<SIMDWIDTH>,
  Sim,
  SIMDWIDTH>(d, trained);
  }
+ case ScalarQuantizer::QT_8bit_direct_signed:
+ return new DCTemplate<
+ Quantizer8bitDirectSigned<SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
  }
  FAISS_THROW_MSG("unknown qtype");
  return nullptr;
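The QT_bf16 case added above stores two bytes per dimension: a bfloat16 code keeps the upper 16 bits of the IEEE-754 float32 pattern (sign, exponent, top mantissa bits). An illustrative round-trip, not taken from the library, with plain truncation shown (an implementation may round to nearest instead):

#include <cstdint>
#include <cstring>

inline uint16_t to_bf16(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16); // keep the high half of the float32 pattern
}

inline float from_bf16(uint16_t code) {
    uint32_t bits = static_cast<uint32_t>(code) << 16; // low mantissa bits become zero
    float x;
    std::memcpy(&x, &bits, sizeof(bits));
    return x;
}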
@@ -1058,6 +1981,7 @@ void ScalarQuantizer::set_derived_sizes() {
  case QT_8bit:
  case QT_8bit_uniform:
  case QT_8bit_direct:
+ case QT_8bit_direct_signed:
  code_size = d;
  bits = 8;
  break;
@@ -1074,6 +1998,10 @@ void ScalarQuantizer::set_derived_sizes() {
  code_size = d * 2;
  bits = 16;
  break;
+ case QT_bf16:
+ code_size = d * 2;
+ bits = 16;
+ break;
  }
  }

@@ -1110,39 +2038,19 @@ void ScalarQuantizer::train(size_t n, const float* x) {
  break;
  case QT_fp16:
  case QT_8bit_direct:
+ case QT_bf16:
+ case QT_8bit_direct_signed:
  // no training necessary
  break;
  }
  }

- void ScalarQuantizer::train_residual(
- size_t n,
- const float* x,
- Index* quantizer,
- bool by_residual,
- bool verbose) {
- const float* x_in = x;
-
- // 100k points more than enough
- x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
-
- ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
-
- if (by_residual) {
- std::vector<idx_t> idx(n);
- quantizer->assign(n, x, idx.data());
-
- std::vector<float> residuals(n * d);
- quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
-
- train(n, residuals.data());
- } else {
- train(n, x);
- }
- }
-
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
- #ifdef USE_F16C
+ #if defined(USE_AVX512_F16C)
+ if (d % 16 == 0) {
+ return select_quantizer_1<16>(qtype, d, trained);
+ } else
+ #elif defined(USE_F16C) || defined(USE_NEON)
  if (d % 8 == 0) {
  return select_quantizer_1<8>(qtype, d, trained);
  } else
@@ -1173,7 +2081,17 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
  SQDistanceComputer* ScalarQuantizer::get_distance_computer(
  MetricType metric) const {
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
- #ifdef USE_F16C
+ #if defined(USE_AVX512_F16C)
+ if (d % 16 == 0) {
+ if (metric == METRIC_L2) {
+ return select_distance_computer<SimilarityL2<16>>(
+ qtype, d, trained);
+ } else {
+ return select_distance_computer<SimilarityIP<16>>(
+ qtype, d, trained);
+ }
+ } else
+ #elif defined(USE_F16C) || defined(USE_NEON)
  if (d % 8 == 0) {
  if (metric == METRIC_L2) {
  return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1204,7 +2122,6 @@ template <class DCClass, int use_sel>
  struct IVFSQScannerIP : InvertedListScanner {
  DCClass dc;
  bool by_residual;
- const IDSelector* sel;

  float accu0; /// added to all distances

@@ -1215,9 +2132,11 @@ struct IVFSQScannerIP : InvertedListScanner {
  bool store_pairs,
  const IDSelector* sel,
  bool by_residual)
- : dc(d, trained), by_residual(by_residual), sel(sel), accu0(0) {
+ : dc(d, trained), by_residual(by_residual), accu0(0) {
  this->store_pairs = store_pairs;
+ this->sel = sel;
  this->code_size = code_size;
+ this->keep_max = true;
  }

  void set_query(const float* query) override {
@@ -1288,7 +2207,6 @@ struct IVFSQScannerL2 : InvertedListScanner {

  bool by_residual;
  const Index* quantizer;
- const IDSelector* sel;
  const float* x; /// current query

  std::vector<float> tmp;
@@ -1304,10 +2222,10 @@ struct IVFSQScannerL2 : InvertedListScanner {
  : dc(d, trained),
  by_residual(by_residual),
  quantizer(quantizer),
- sel(sel),
  x(nullptr),
  tmp(d) {
  this->store_pairs = store_pairs;
+ this->sel = sel;
  this->code_size = code_size;
  }

@@ -1422,7 +2340,7 @@ InvertedListScanner* sel2_InvertedListScanner(
  }
  }

- template <class Similarity, class Codec, bool uniform>
+ template <class Similarity, class Codec, QuantizerTemplateScaling SCALING>
  InvertedListScanner* sel12_InvertedListScanner(
  const ScalarQuantizer* sq,
  const Index* quantizer,
@@ -1430,7 +2348,7 @@ InvertedListScanner* sel12_InvertedListScanner(
  const IDSelector* sel,
  bool r) {
  constexpr int SIMDWIDTH = Similarity::simdwidth;
- using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
+ using QuantizerClass = QuantizerTemplate<Codec, SCALING, SIMDWIDTH>;
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
  return sel2_InvertedListScanner<DCClass>(
  sq, quantizer, store_pairs, sel, r);
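In this hunk, as in select_distance_computer above, the bool uniform template parameter of QuantizerTemplate is replaced by a named QuantizerTemplateScaling value. The enum itself is declared earlier in the file, outside this excerpt; a sketch of its presumed shape, with the enumerator values being an assumption:

// Presumed declaration (not shown in this excerpt); the exact values are assumed.
enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 };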
@@ -1446,36 +2364,70 @@ InvertedListScanner* sel1_InvertedListScanner(
  constexpr int SIMDWIDTH = Similarity::simdwidth;
  switch (sq->qtype) {
  case ScalarQuantizer::QT_8bit_uniform:
- return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
+ return sel12_InvertedListScanner<
+ Similarity,
+ Codec8bit,
+ QuantizerTemplateScaling::UNIFORM>(
  sq, quantizer, store_pairs, sel, r);
  case ScalarQuantizer::QT_4bit_uniform:
- return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
+ return sel12_InvertedListScanner<
+ Similarity,
+ Codec4bit,
+ QuantizerTemplateScaling::UNIFORM>(
  sq, quantizer, store_pairs, sel, r);
  case ScalarQuantizer::QT_8bit:
- return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
+ return sel12_InvertedListScanner<
+ Similarity,
+ Codec8bit,
+ QuantizerTemplateScaling::NON_UNIFORM>(
  sq, quantizer, store_pairs, sel, r);
  case ScalarQuantizer::QT_4bit:
- return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
+ return sel12_InvertedListScanner<
+ Similarity,
+ Codec4bit,
+ QuantizerTemplateScaling::NON_UNIFORM>(
  sq, quantizer, store_pairs, sel, r);
  case ScalarQuantizer::QT_6bit:
- return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
+ return sel12_InvertedListScanner<
+ Similarity,
+ Codec6bit,
+ QuantizerTemplateScaling::NON_UNIFORM>(
  sq, quantizer, store_pairs, sel, r);
  case ScalarQuantizer::QT_fp16:
  return sel2_InvertedListScanner<DCTemplate<
  QuantizerFP16<SIMDWIDTH>,
  Similarity,
  SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
+ case ScalarQuantizer::QT_bf16:
+ return sel2_InvertedListScanner<DCTemplate<
+ QuantizerBF16<SIMDWIDTH>,
+ Similarity,
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
  case ScalarQuantizer::QT_8bit_direct:
+ #if defined(__AVX512F__)
+ if (sq->d % 32 == 0) {
+ return sel2_InvertedListScanner<
+ DistanceComputerByte<Similarity, SIMDWIDTH>>(
+ sq, quantizer, store_pairs, sel, r);
+ } else
+ #elif defined(__AVX2__)
  if (sq->d % 16 == 0) {
  return sel2_InvertedListScanner<
  DistanceComputerByte<Similarity, SIMDWIDTH>>(
  sq, quantizer, store_pairs, sel, r);
- } else {
+ } else
+ #endif
+ {
  return sel2_InvertedListScanner<DCTemplate<
  Quantizer8bitDirect<SIMDWIDTH>,
  Similarity,
  SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
  }
+ case ScalarQuantizer::QT_8bit_direct_signed:
+ return sel2_InvertedListScanner<DCTemplate<
+ Quantizer8bitDirectSigned<SIMDWIDTH>,
+ Similarity,
+ SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
  }

  FAISS_THROW_MSG("unknown qtype");
@@ -1509,7 +2461,12 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
  bool store_pairs,
  const IDSelector* sel,
  bool by_residual) const {
- #ifdef USE_F16C
+ #if defined(USE_AVX512_F16C)
+ if (d % 16 == 0) {
+ return sel0_InvertedListScanner<16>(
+ mt, this, quantizer, store_pairs, sel, by_residual);
+ } else
+ #elif defined(USE_F16C) || defined(USE_NEON)
  if (d % 8 == 0) {
  return sel0_InvertedListScanner<8>(
  mt, this, quantizer, store_pairs, sel, by_residual);