faiss 0.3.1 → 0.3.2

Files changed (119)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.h +1 -1
  5. data/vendor/faiss/faiss/Clustering.cpp +35 -4
  6. data/vendor/faiss/faiss/Clustering.h +10 -1
  7. data/vendor/faiss/faiss/IVFlib.cpp +4 -1
  8. data/vendor/faiss/faiss/Index.h +21 -6
  9. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  10. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +1 -1
  11. data/vendor/faiss/faiss/IndexFastScan.cpp +22 -4
  12. data/vendor/faiss/faiss/IndexFlat.cpp +11 -7
  13. data/vendor/faiss/faiss/IndexFlatCodes.cpp +159 -5
  14. data/vendor/faiss/faiss/IndexFlatCodes.h +20 -3
  15. data/vendor/faiss/faiss/IndexHNSW.cpp +143 -90
  16. data/vendor/faiss/faiss/IndexHNSW.h +52 -3
  17. data/vendor/faiss/faiss/IndexIVF.cpp +3 -3
  18. data/vendor/faiss/faiss/IndexIVF.h +9 -1
  19. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +15 -0
  20. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -0
  21. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +130 -57
  22. data/vendor/faiss/faiss/IndexIVFFastScan.h +14 -7
  23. data/vendor/faiss/faiss/IndexIVFPQ.cpp +1 -3
  24. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +21 -2
  25. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  26. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  27. data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -29
  28. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  29. data/vendor/faiss/faiss/IndexNSG.h +1 -1
  30. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  31. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  32. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  33. data/vendor/faiss/faiss/IndexRefine.cpp +5 -5
  34. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +3 -1
  35. data/vendor/faiss/faiss/MetricType.h +7 -2
  36. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  37. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  38. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  39. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  40. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +36 -4
  41. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +6 -0
  42. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  43. data/vendor/faiss/faiss/gpu/GpuIndex.h +2 -8
  44. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  45. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +6 -0
  46. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +2 -0
  47. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +25 -0
  48. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  49. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +6 -0
  50. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  51. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +65 -0
  52. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  53. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  54. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  55. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  56. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +25 -0
  57. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +9 -1
  58. data/vendor/faiss/faiss/impl/DistanceComputer.h +46 -0
  59. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  60. data/vendor/faiss/faiss/impl/HNSW.cpp +358 -190
  61. data/vendor/faiss/faiss/impl/HNSW.h +43 -22
  62. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +8 -8
  63. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  64. data/vendor/faiss/faiss/impl/NNDescent.cpp +13 -8
  65. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  66. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +1 -0
  67. data/vendor/faiss/faiss/impl/ProductQuantizer.h +5 -1
  68. data/vendor/faiss/faiss/impl/ResultHandler.h +151 -32
  69. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +719 -102
  70. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -0
  71. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +5 -0
  72. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  73. data/vendor/faiss/faiss/impl/index_read.cpp +29 -15
  74. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  75. data/vendor/faiss/faiss/impl/index_write.cpp +28 -10
  76. data/vendor/faiss/faiss/impl/io.cpp +13 -5
  77. data/vendor/faiss/faiss/impl/io.h +4 -4
  78. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  79. data/vendor/faiss/faiss/impl/platform_macros.h +22 -0
  80. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +11 -0
  81. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +1 -1
  82. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +448 -1
  83. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +5 -5
  84. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +1 -1
  85. data/vendor/faiss/faiss/impl/simd_result_handlers.h +143 -59
  86. data/vendor/faiss/faiss/index_factory.cpp +31 -13
  87. data/vendor/faiss/faiss/index_io.h +12 -5
  88. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  89. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  90. data/vendor/faiss/faiss/invlists/DirectMap.cpp +9 -1
  91. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +55 -17
  92. data/vendor/faiss/faiss/invlists/InvertedLists.h +18 -9
  93. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +21 -6
  94. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  95. data/vendor/faiss/faiss/python/python_callbacks.cpp +3 -3
  96. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  97. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  98. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  99. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  100. data/vendor/faiss/faiss/utils/distances.cpp +58 -88
  101. data/vendor/faiss/faiss/utils/distances.h +5 -5
  102. data/vendor/faiss/faiss/utils/distances_simd.cpp +997 -9
  103. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  104. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  105. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  106. data/vendor/faiss/faiss/utils/hamming.cpp +1 -1
  107. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +4 -1
  108. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +2 -1
  109. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  110. data/vendor/faiss/faiss/utils/random.h +25 -0
  111. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  112. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  113. data/vendor/faiss/faiss/utils/simdlib_neon.h +5 -2
  114. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  115. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  116. data/vendor/faiss/faiss/utils/utils.cpp +10 -3
  117. data/vendor/faiss/faiss/utils/utils.h +3 -0
  118. metadata +16 -4
  119. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
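
data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp

The diff below covers this file (+719 −102, item 69 in the list above), the largest change in the vendored faiss sources for this release. At a glance: AVX-512 code paths (16-lane codecs, similarity accumulators, and distance computers) are added ahead of the existing AVX2 ones; a bfloat16 quantizer (QT_bf16) and a signed direct 8-bit quantizer (QT_8bit_direct_signed) are introduced; the boolean `uniform` template parameter becomes the QuantizerTemplateScaling enum; and the aarch64 code is gated behind a new USE_NEON macro and stripped of redundant zip/unzip shuffles (see the note after the first such hunk below).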
@@ -23,6 +23,7 @@
 #include <faiss/impl/AuxIndexStructures.h>
 #include <faiss/impl/FaissAssert.h>
 #include <faiss/impl/IDSelector.h>
+#include <faiss/utils/bf16.h>
 #include <faiss/utils/fp16.h>
 #include <faiss/utils/utils.h>

@@ -43,7 +44,9 @@ namespace faiss {
  * that hides the template mess.
  ********************************************************************/

-#ifdef __AVX2__
+#if defined(__AVX512F__) && defined(__F16C__)
+#define USE_AVX512_F16C
+#elif defined(__AVX2__)
 #ifdef __F16C__
 #define USE_F16C
 #else
@@ -52,6 +55,15 @@ namespace faiss {
 #endif
 #endif

+#if defined(__aarch64__)
+#if defined(__GNUC__) && __GNUC__ < 8
+#warning \
+        "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8"
+#else
+#define USE_NEON
+#endif
+#endif
+
 namespace {

 typedef ScalarQuantizer::QuantizerType QuantizerType;
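
Note: every bare #ifdef __aarch64__ in this file is rewritten below to #ifdef USE_NEON, so the NEON code is now compiled only with clang or GCC 8+. The GCC < 8 cutoff is presumably because the paired-register load intrinsics this file relies on (vld1q_f32_x2, and the vld1_u16_x2 introduced further down) are missing from older GCC releases.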
@@ -78,7 +90,17 @@ struct Codec8bit {
         return (code[i] + 0.5f) / 255.0f;
     }

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+    static FAISS_ALWAYS_INLINE __m512
+    decode_16_components(const uint8_t* code, int i) {
+        const __m128i c16 = _mm_loadu_si128((__m128i*)(code + i));
+        const __m512i i32 = _mm512_cvtepu8_epi32(c16);
+        const __m512 f16 = _mm512_cvtepi32_ps(i32);
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 255.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 255.f);
+        return _mm512_fmadd_ps(f16, one_255, half_one_255);
+    }
+#elif defined(__AVX2__)
     static FAISS_ALWAYS_INLINE __m256
     decode_8_components(const uint8_t* code, int i) {
         const uint64_t c8 = *(uint64_t*)(code + i);
@@ -92,7 +114,7 @@ struct Codec8bit {
     }
 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON
     static FAISS_ALWAYS_INLINE float32x4x2_t
     decode_8_components(const uint8_t* code, int i) {
         float32_t result[8] = {};
@@ -101,8 +123,7 @@ struct Codec8bit {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 #endif
 };
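
A pattern that repeats throughout the NEON changes in this diff: the old code built a float32x4x2_t by zipping two registers and immediately un-zipping the result. vuzpq_f32 is the inverse permutation of vzipq_f32, so the round trip is the identity and a plain braced initializer is equivalent. A minimal sketch (assumes an aarch64 target; the function names are illustrative):

    #include <arm_neon.h>

    // vzipq_f32 interleaves a and b; vuzpq_f32 de-interleaves again,
    // so the pair of shuffles returns exactly {a, b}.
    float32x4x2_t zip_then_unzip(float32x4_t a, float32x4_t b) {
        float32x4x2_t z = vzipq_f32(a, b);
        return vuzpq_f32(z.val[0], z.val[1]); // == {a, b}
    }

    float32x4x2_t direct(float32x4_t a, float32x4_t b) {
        return {a, b}; // what the new code does, with no shuffle instructions
    }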
@@ -121,7 +142,26 @@ struct Codec4bit {
         return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
     }

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+    static FAISS_ALWAYS_INLINE __m512
+    decode_16_components(const uint8_t* code, int i) {
+        uint64_t c8 = *(uint64_t*)(code + (i >> 1));
+        uint64_t mask = 0x0f0f0f0f0f0f0f0f;
+        uint64_t c8ev = c8 & mask;
+        uint64_t c8od = (c8 >> 4) & mask;
+
+        __m128i c16 =
+                _mm_unpacklo_epi8(_mm_set1_epi64x(c8ev), _mm_set1_epi64x(c8od));
+        __m256i c8lo = _mm256_cvtepu8_epi32(c16);
+        __m256i c8hi = _mm256_cvtepu8_epi32(_mm_srli_si128(c16, 8));
+        __m512i i16 = _mm512_castsi256_si512(c8lo);
+        i16 = _mm512_inserti32x8(i16, c8hi, 1);
+        __m512 f16 = _mm512_cvtepi32_ps(i16);
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 15.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 15.f);
+        return _mm512_fmadd_ps(f16, one_255, half_one_255);
+    }
+#elif defined(__AVX2__)
     static FAISS_ALWAYS_INLINE __m256
     decode_8_components(const uint8_t* code, int i) {
         uint32_t c4 = *(uint32_t*)(code + (i >> 1));
@@ -144,7 +184,7 @@ struct Codec4bit {
     }
 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON
     static FAISS_ALWAYS_INLINE float32x4x2_t
     decode_8_components(const uint8_t* code, int i) {
         float32_t result[8] = {};
@@ -153,8 +193,7 @@ struct Codec4bit {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 #endif
 };
@@ -208,7 +247,56 @@ struct Codec6bit {
         return (bits + 0.5f) / 63.0f;
     }

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+    static FAISS_ALWAYS_INLINE __m512
+    decode_16_components(const uint8_t* code, int i) {
+        // pure AVX512 implementation (not necessarily the fastest).
+        // see:
+        // https://github.com/zilliztech/knowhere/blob/main/thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
+
+        // clang-format off
+
+        // 16 components, 16x6 bit=12 bytes
+        const __m128i bit_6v =
+                _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
+        const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
+
+        // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
+        // 00 01 02 03
+        const __m256i shuffle_mask = _mm256_setr_epi16(
+                0xFF00, 0x0100, 0x0201, 0xFF02,
+                0xFF03, 0x0403, 0x0504, 0xFF05,
+                0xFF06, 0x0706, 0x0807, 0xFF08,
+                0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
+        const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
+
+        // 0: xxxxxxxx xx543210
+        // 1: xxxx5432 10xxxxxx
+        // 2: xxxxxx54 3210xxxx
+        // 3: xxxxxxxx 543210xx
+        const __m256i shift_right_v = _mm256_setr_epi16(
+                0x0U, 0x6U, 0x4U, 0x2U,
+                0x0U, 0x6U, 0x4U, 0x2U,
+                0x0U, 0x6U, 0x4U, 0x2U,
+                0x0U, 0x6U, 0x4U, 0x2U);
+        __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
+
+        // remove unneeded bits
+        shuffled_shifted =
+                _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
+
+        // scale
+        const __m512 f8 =
+                _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
+        return _mm512_fmadd_ps(f8, one_255, half_one_255);
+
+        // clang-format on
+    }
+
+#elif defined(__AVX2__)

     /* Load 6 bytes that represent 8 6-bit values, return them as a
      * 8*32 bit vector register */
@@ -257,7 +345,7 @@ struct Codec6bit {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON
     static FAISS_ALWAYS_INLINE float32x4x2_t
     decode_8_components(const uint8_t* code, int i) {
         float32_t result[8] = {};
@@ -266,8 +354,7 @@ struct Codec6bit {
         }
         float32x4_t res1 = vld1q_f32(result);
         float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {res1, res2};
     }
 #endif
 };
@@ -277,11 +364,14 @@ struct Codec6bit {
  * through a codec
  *******************************************************************/

-template <class Codec, bool uniform, int SIMD>
+enum class QuantizerTemplateScaling { UNIFORM = 0, NON_UNIFORM = 1 };
+
+template <class Codec, QuantizerTemplateScaling SCALING, int SIMD>
 struct QuantizerTemplate {};

 template <class Codec>
-struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>
+        : ScalarQuantizer::SQuantizer {
     const size_t d;
     const float vmin, vdiff;

@@ -318,12 +408,33 @@ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::SQuantizer {
     }
 };

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <class Codec>
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 16>
+        : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
+    QuantizerTemplate(size_t d, const std::vector<float>& trained)
+            : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
+                      d,
+                      trained) {}
+
+    FAISS_ALWAYS_INLINE __m512
+    reconstruct_16_components(const uint8_t* code, int i) const {
+        __m512 xi = Codec::decode_16_components(code, i);
+        return _mm512_fmadd_ps(
+                xi, _mm512_set1_ps(this->vdiff), _mm512_set1_ps(this->vmin));
+    }
+};
+
+#elif defined(__AVX2__)

 template <class Codec>
-struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
+        : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
-            : QuantizerTemplate<Codec, true, 1>(d, trained) {}
+            : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
+                      d,
+                      trained) {}

     FAISS_ALWAYS_INLINE __m256
     reconstruct_8_components(const uint8_t* code, int i) const {
@@ -335,33 +446,35 @@ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <class Codec>
-struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 8>
+        : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
-            : QuantizerTemplate<Codec, true, 1>(d, trained) {}
+            : QuantizerTemplate<Codec, QuantizerTemplateScaling::UNIFORM, 1>(
+                      d,
+                      trained) {}

     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
         float32x4x2_t xi = Codec::decode_8_components(code, i);
-        float32x4x2_t res = vzipq_f32(
-                vfmaq_f32(
+        return {vfmaq_f32(
                         vdupq_n_f32(this->vmin),
                         xi.val[0],
                         vdupq_n_f32(this->vdiff)),
                 vfmaq_f32(
                         vdupq_n_f32(this->vmin),
                         xi.val[1],
-                        vdupq_n_f32(this->vdiff)));
-        return vuzpq_f32(res.val[0], res.val[1]);
+                        vdupq_n_f32(this->vdiff))};
     }
 };

 #endif

 template <class Codec>
-struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1>
+        : ScalarQuantizer::SQuantizer {
     const size_t d;
     const float *vmin, *vdiff;

@@ -398,12 +511,37 @@ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::SQuantizer {
     }
 };

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <class Codec>
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 16>
+        : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
+    QuantizerTemplate(size_t d, const std::vector<float>& trained)
+            : QuantizerTemplate<
+                      Codec,
+                      QuantizerTemplateScaling::NON_UNIFORM,
+                      1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m512
+    reconstruct_16_components(const uint8_t* code, int i) const {
+        __m512 xi = Codec::decode_16_components(code, i);
+        return _mm512_fmadd_ps(
+                xi,
+                _mm512_loadu_ps(this->vdiff + i),
+                _mm512_loadu_ps(this->vmin + i));
+    }
+};
+
+#elif defined(__AVX2__)

 template <class Codec>
-struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
+        : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
-            : QuantizerTemplate<Codec, false, 1>(d, trained) {}
+            : QuantizerTemplate<
+                      Codec,
+                      QuantizerTemplateScaling::NON_UNIFORM,
+                      1>(d, trained) {}

     FAISS_ALWAYS_INLINE __m256
     reconstruct_8_components(const uint8_t* code, int i) const {
@@ -417,12 +555,16 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <class Codec>
-struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
+struct QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 8>
+        : QuantizerTemplate<Codec, QuantizerTemplateScaling::NON_UNIFORM, 1> {
     QuantizerTemplate(size_t d, const std::vector<float>& trained)
-            : QuantizerTemplate<Codec, false, 1>(d, trained) {}
+            : QuantizerTemplate<
+                      Codec,
+                      QuantizerTemplateScaling::NON_UNIFORM,
+                      1>(d, trained) {}

     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
@@ -431,10 +573,8 @@ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
         float32x4x2_t vmin_8 = vld1q_f32_x2(this->vmin + i);
         float32x4x2_t vdiff_8 = vld1q_f32_x2(this->vdiff + i);

-        float32x4x2_t res = vzipq_f32(
-                vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
-                vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1]));
-        return vuzpq_f32(res.val[0], res.val[1]);
+        return {vfmaq_f32(vmin_8.val[0], xi.val[0], vdiff_8.val[0]),
+                vfmaq_f32(vmin_8.val[1], xi.val[1], vdiff_8.val[1])};
     }
 };

@@ -471,7 +611,23 @@ struct QuantizerFP16<1> : ScalarQuantizer::SQuantizer {
     }
 };

-#ifdef USE_F16C
+#if defined(USE_AVX512_F16C)
+
+template <>
+struct QuantizerFP16<16> : QuantizerFP16<1> {
+    QuantizerFP16(size_t d, const std::vector<float>& trained)
+            : QuantizerFP16<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m512
+    reconstruct_16_components(const uint8_t* code, int i) const {
+        __m256i codei = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
+        return _mm512_cvtph_ps(codei);
+    }
+};
+
+#endif
+
+#if defined(USE_F16C)

 template <>
 struct QuantizerFP16<8> : QuantizerFP16<1> {
@@ -487,7 +643,7 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <>
 struct QuantizerFP16<8> : QuantizerFP16<1> {
@@ -496,10 +652,90 @@ struct QuantizerFP16<8> : QuantizerFP16<1> {

     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
-        uint16x4x2_t codei = vld2_u16((const uint16_t*)(code + 2 * i));
-        return vzipq_f32(
-                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
-                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1])));
+        uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
+        return {vcvt_f32_f16(vreinterpret_f16_u16(codei.val[0])),
+                vcvt_f32_f16(vreinterpret_f16_u16(codei.val[1]))};
+    }
+};
+#endif
+
+/*******************************************************************
+ * BF16 quantizer
+ *******************************************************************/
+
+template <int SIMDWIDTH>
+struct QuantizerBF16 {};
+
+template <>
+struct QuantizerBF16<1> : ScalarQuantizer::SQuantizer {
+    const size_t d;
+
+    QuantizerBF16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
+
+    void encode_vector(const float* x, uint8_t* code) const final {
+        for (size_t i = 0; i < d; i++) {
+            ((uint16_t*)code)[i] = encode_bf16(x[i]);
+        }
+    }
+
+    void decode_vector(const uint8_t* code, float* x) const final {
+        for (size_t i = 0; i < d; i++) {
+            x[i] = decode_bf16(((uint16_t*)code)[i]);
+        }
+    }
+
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
+        return decode_bf16(((uint16_t*)code)[i]);
+    }
+};
+
+#if defined(__AVX512F__)
+
+template <>
+struct QuantizerBF16<16> : QuantizerBF16<1> {
+    QuantizerBF16(size_t d, const std::vector<float>& trained)
+            : QuantizerBF16<1>(d, trained) {}
+    FAISS_ALWAYS_INLINE __m512
+    reconstruct_16_components(const uint8_t* code, int i) const {
+        __m256i code_256i = _mm256_loadu_si256((const __m256i*)(code + 2 * i));
+        __m512i code_512i = _mm512_cvtepu16_epi32(code_256i);
+        code_512i = _mm512_slli_epi32(code_512i, 16);
+        return _mm512_castsi512_ps(code_512i);
+    }
+};
+
+#elif defined(__AVX2__)
+
+template <>
+struct QuantizerBF16<8> : QuantizerBF16<1> {
+    QuantizerBF16(size_t d, const std::vector<float>& trained)
+            : QuantizerBF16<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        __m128i code_128i = _mm_loadu_si128((const __m128i*)(code + 2 * i));
+        __m256i code_256i = _mm256_cvtepu16_epi32(code_128i);
+        code_256i = _mm256_slli_epi32(code_256i, 16);
+        return _mm256_castsi256_ps(code_256i);
+    }
+};
+
+#endif
+
+#ifdef USE_NEON
+
+template <>
+struct QuantizerBF16<8> : QuantizerBF16<1> {
+    QuantizerBF16(size_t d, const std::vector<float>& trained)
+            : QuantizerBF16<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        uint16x4x2_t codei = vld1_u16_x2((const uint16_t*)(code + 2 * i));
+        return {vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(codei.val[0]), 16)),
+                vreinterpretq_f32_u32(
+                        vshlq_n_u32(vmovl_u16(codei.val[1]), 16))};
     }
 };
 #endif
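
The encode_bf16/decode_bf16 helpers used above live in the newly added faiss/utils/bf16.h (item 99 in the file list). bfloat16 is simply the top half of an IEEE-754 float32, which is why the SIMD decoders above amount to a widening load plus a 16-bit left shift. A minimal sketch of the round trip, assuming the simplest truncating encoder (the real header may round to nearest before truncating):

    #include <cstdint>
    #include <cstring>

    // bfloat16: 1 sign bit, 8 exponent bits, 7 mantissa bits --
    // the high 16 bits of a float32.
    inline uint16_t encode_bf16_trunc(float f) {
        uint32_t bits;
        std::memcpy(&bits, &f, sizeof(bits));
        return static_cast<uint16_t>(bits >> 16); // drop low mantissa bits
    }

    inline float decode_bf16_sketch(uint16_t v) {
        uint32_t bits = static_cast<uint32_t>(v) << 16; // same shift as above
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }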
@@ -536,7 +772,22 @@ struct Quantizer8bitDirect<1> : ScalarQuantizer::SQuantizer {
     }
 };

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <>
+struct Quantizer8bitDirect<16> : Quantizer8bitDirect<1> {
+    Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirect<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m512
+    reconstruct_16_components(const uint8_t* code, int i) const {
+        __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
+        __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
+        return _mm512_cvtepi32_ps(y16); // 16 * float32
+    }
+};
+
+#elif defined(__AVX2__)

 template <>
 struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
@@ -553,7 +804,7 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <>
 struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
@@ -562,14 +813,107 @@ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {

     FAISS_ALWAYS_INLINE float32x4x2_t
     reconstruct_8_components(const uint8_t* code, int i) const {
-        float32_t result[8] = {};
-        for (size_t j = 0; j < 8; j++) {
-            result[j] = code[i + j];
+        uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
+        uint16x8_t y8 = vmovl_u8(x8);
+        uint16x4_t y8_0 = vget_low_u16(y8);
+        uint16x4_t y8_1 = vget_high_u16(y8);
+
+        // convert uint16 -> uint32 -> fp32
+        return {vcvtq_f32_u32(vmovl_u16(y8_0)), vcvtq_f32_u32(vmovl_u16(y8_1))};
+    }
+};
+
+#endif
+
+/*******************************************************************
+ * 8bit_direct_signed quantizer
+ *******************************************************************/
+
+template <int SIMDWIDTH>
+struct Quantizer8bitDirectSigned {};
+
+template <>
+struct Quantizer8bitDirectSigned<1> : ScalarQuantizer::SQuantizer {
+    const size_t d;
+
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& /* unused */)
+            : d(d) {}
+
+    void encode_vector(const float* x, uint8_t* code) const final {
+        for (size_t i = 0; i < d; i++) {
+            code[i] = (uint8_t)(x[i] + 128);
         }
-        float32x4_t res1 = vld1q_f32(result);
-        float32x4_t res2 = vld1q_f32(result + 4);
-        float32x4x2_t res = vzipq_f32(res1, res2);
-        return vuzpq_f32(res.val[0], res.val[1]);
+    }
+
+    void decode_vector(const uint8_t* code, float* x) const final {
+        for (size_t i = 0; i < d; i++) {
+            x[i] = code[i] - 128;
+        }
+    }
+
+    FAISS_ALWAYS_INLINE float reconstruct_component(const uint8_t* code, int i)
+            const {
+        return code[i] - 128;
+    }
+};
+
+#if defined(__AVX512F__)
+
+template <>
+struct Quantizer8bitDirectSigned<16> : Quantizer8bitDirectSigned<1> {
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirectSigned<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m512
+    reconstruct_16_components(const uint8_t* code, int i) const {
+        __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
+        __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
+        __m512i c16 = _mm512_set1_epi32(128);
+        __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
+        return _mm512_cvtepi32_ps(z16); // 16 * float32
+    }
+};
+
+#elif defined(__AVX2__)
+
+template <>
+struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirectSigned<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE __m256
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
+        __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
+        __m256i c8 = _mm256_set1_epi32(128);
+        __m256i z8 = _mm256_sub_epi32(y8, c8); // subtract 128 from all lanes
+        return _mm256_cvtepi32_ps(z8); // 8 * float32
+    }
+};
+
+#endif
+
+#ifdef USE_NEON
+
+template <>
+struct Quantizer8bitDirectSigned<8> : Quantizer8bitDirectSigned<1> {
+    Quantizer8bitDirectSigned(size_t d, const std::vector<float>& trained)
+            : Quantizer8bitDirectSigned<1>(d, trained) {}
+
+    FAISS_ALWAYS_INLINE float32x4x2_t
+    reconstruct_8_components(const uint8_t* code, int i) const {
+        uint8x8_t x8 = vld1_u8((const uint8_t*)(code + i));
+        uint16x8_t y8 = vmovl_u8(x8); // convert uint8 -> uint16
+        uint16x4_t y8_0 = vget_low_u16(y8);
+        uint16x4_t y8_1 = vget_high_u16(y8);
+
+        float32x4_t z8_0 = vcvtq_f32_u32(
+                vmovl_u16(y8_0)); // convert uint16 -> uint32 -> fp32
+        float32x4_t z8_1 = vcvtq_f32_u32(vmovl_u16(y8_1));
+
+        // subtract 128 to convert into signed numbers
+        return {vsubq_f32(z8_0, vmovq_n_f32(128.0)),
+                vsubq_f32(z8_1, vmovq_n_f32(128.0))};
     }
 };

@@ -582,24 +926,38 @@ ScalarQuantizer::SQuantizer* select_quantizer_1(
         const std::vector<float>& trained) {
     switch (qtype) {
         case ScalarQuantizer::QT_8bit:
-            return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
-                    d, trained);
+            return new QuantizerTemplate<
+                    Codec8bit,
+                    QuantizerTemplateScaling::NON_UNIFORM,
+                    SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_6bit:
-            return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
-                    d, trained);
+            return new QuantizerTemplate<
+                    Codec6bit,
+                    QuantizerTemplateScaling::NON_UNIFORM,
+                    SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_4bit:
-            return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
-                    d, trained);
+            return new QuantizerTemplate<
+                    Codec4bit,
+                    QuantizerTemplateScaling::NON_UNIFORM,
+                    SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_8bit_uniform:
-            return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
-                    d, trained);
+            return new QuantizerTemplate<
+                    Codec8bit,
+                    QuantizerTemplateScaling::UNIFORM,
+                    SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_4bit_uniform:
-            return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
-                    d, trained);
+            return new QuantizerTemplate<
+                    Codec4bit,
+                    QuantizerTemplateScaling::UNIFORM,
+                    SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_fp16:
             return new QuantizerFP16<SIMDWIDTH>(d, trained);
+        case ScalarQuantizer::QT_bf16:
+            return new QuantizerBF16<SIMDWIDTH>(d, trained);
         case ScalarQuantizer::QT_8bit_direct:
             return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
+        case ScalarQuantizer::QT_8bit_direct_signed:
+            return new Quantizer8bitDirectSigned<SIMDWIDTH>(d, trained);
     }
     FAISS_THROW_MSG("unknown qtype");
 }
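
With the two new cases wired into select_quantizer_1, QT_bf16 and QT_8bit_direct_signed can be used like any other scalar quantizer type; neither needs training statistics. A usage sketch (hypothetical data and dimension; d is kept a multiple of 16 so the 16-lane dispatch added later in this diff can apply on AVX-512 builds):

    #include <faiss/IndexScalarQuantizer.h>
    #include <vector>

    int d = 64;                   // vector dimension
    size_t n = 1000;              // number of database vectors
    std::vector<float> xb(n * d); // filled elsewhere

    faiss::IndexScalarQuantizer index(
            d, faiss::ScalarQuantizer::QT_bf16, faiss::METRIC_L2);
    index.train(n, xb.data()); // a no-op for QT_bf16 / QT_8bit_direct_signed
    index.add(n, xb.data());   // stored at 2 bytes per component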
@@ -816,7 +1174,43 @@ struct SimilarityL2<1> {
     }
 };

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <>
+struct SimilarityL2<16> {
+    static constexpr int simdwidth = 16;
+    static constexpr MetricType metric_type = METRIC_L2;
+
+    const float *y, *yi;
+
+    explicit SimilarityL2(const float* y) : y(y) {}
+    __m512 accu16;
+
+    FAISS_ALWAYS_INLINE void begin_16() {
+        accu16 = _mm512_setzero_ps();
+        yi = y;
+    }
+
+    FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
+        __m512 yiv = _mm512_loadu_ps(yi);
+        yi += 16;
+        __m512 tmp = _mm512_sub_ps(yiv, x);
+        accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
+    }
+
+    FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x, __m512 y_2) {
+        __m512 tmp = _mm512_sub_ps(y_2, x);
+        accu16 = _mm512_fmadd_ps(tmp, tmp, accu16);
+    }
+
+    FAISS_ALWAYS_INLINE float result_16() {
+        // performs better than dividing into _mm256 and adding
+        return _mm512_reduce_add_ps(accu16);
+    }
+};
+
+#elif defined(__AVX2__)
+
 template <>
 struct SimilarityL2<8> {
     static constexpr int simdwidth = 8;
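
The begin_16/add_16_components/result_16 methods follow the same accumulator protocol as the existing 1- and 8-lane similarity objects. For reference, a scalar equivalent of what SimilarityL2<16> computes, 16 lanes at a time with one FMA per step:

    // sum_i (y[i] - x[i])^2, accumulated exactly as the AVX-512 code does
    float l2_sqr(const float* x, const float* y, size_t d) {
        float accu = 0.0f;
        for (size_t i = 0; i < d; i++) {
            float diff = y[i] - x[i];
            accu += diff * diff;
        }
        return accu;
    }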
@@ -857,7 +1251,7 @@ struct SimilarityL2<8> {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON
 template <>
 struct SimilarityL2<8> {
     static constexpr int simdwidth = 8;
@@ -868,7 +1262,7 @@ struct SimilarityL2<8> {
     float32x4x2_t accu8;

     FAISS_ALWAYS_INLINE void begin_8() {
-        accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
+        accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
         yi = y;
     }

@@ -882,8 +1276,7 @@ struct SimilarityL2<8> {
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);

-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }

     FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -895,8 +1288,7 @@ struct SimilarityL2<8> {
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], sub0, sub0);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], sub1, sub1);

-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }

     FAISS_ALWAYS_INLINE float result_8() {
@@ -941,7 +1333,43 @@ struct SimilarityIP<1> {
     }
 };

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <>
+struct SimilarityIP<16> {
+    static constexpr int simdwidth = 16;
+    static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
+
+    const float *y, *yi;
+
+    float accu;
+
+    explicit SimilarityIP(const float* y) : y(y) {}
+
+    __m512 accu16;
+
+    FAISS_ALWAYS_INLINE void begin_16() {
+        accu16 = _mm512_setzero_ps();
+        yi = y;
+    }
+
+    FAISS_ALWAYS_INLINE void add_16_components(__m512 x) {
+        __m512 yiv = _mm512_loadu_ps(yi);
+        yi += 16;
+        accu16 = _mm512_fmadd_ps(yiv, x, accu16);
+    }
+
+    FAISS_ALWAYS_INLINE void add_16_components_2(__m512 x1, __m512 x2) {
+        accu16 = _mm512_fmadd_ps(x1, x2, accu16);
+    }
+
+    FAISS_ALWAYS_INLINE float result_16() {
+        // performs better than dividing into _mm256 and adding
+        return _mm512_reduce_add_ps(accu16);
+    }
+};
+
+#elif defined(__AVX2__)

 template <>
 struct SimilarityIP<8> {
@@ -983,7 +1411,7 @@ struct SimilarityIP<8> {
 };
 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <>
 struct SimilarityIP<8> {
@@ -996,7 +1424,7 @@ struct SimilarityIP<8> {
     float32x4x2_t accu8;

     FAISS_ALWAYS_INLINE void begin_8() {
-        accu8 = vzipq_f32(vdupq_n_f32(0.0f), vdupq_n_f32(0.0f));
+        accu8 = {vdupq_n_f32(0.0f), vdupq_n_f32(0.0f)};
         yi = y;
     }

@@ -1006,8 +1434,7 @@ struct SimilarityIP<8> {

         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], yiv.val[0], x.val[0]);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], yiv.val[1], x.val[1]);
-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }

     FAISS_ALWAYS_INLINE void add_8_components_2(
@@ -1015,19 +1442,17 @@ struct SimilarityIP<8> {
             float32x4x2_t x2) {
         float32x4_t accu8_0 = vfmaq_f32(accu8.val[0], x1.val[0], x2.val[0]);
         float32x4_t accu8_1 = vfmaq_f32(accu8.val[1], x1.val[1], x2.val[1]);
-        float32x4x2_t accu8_temp = vzipq_f32(accu8_0, accu8_1);
-        accu8 = vuzpq_f32(accu8_temp.val[0], accu8_temp.val[1]);
+        accu8 = {accu8_0, accu8_1};
     }

     FAISS_ALWAYS_INLINE float result_8() {
-        float32x4x2_t sum_tmp = vzipq_f32(
+        float32x4x2_t sum = {
                 vpaddq_f32(accu8.val[0], accu8.val[0]),
-                vpaddq_f32(accu8.val[1], accu8.val[1]));
-        float32x4x2_t sum = vuzpq_f32(sum_tmp.val[0], sum_tmp.val[1]);
-        float32x4x2_t sum2_tmp = vzipq_f32(
+                vpaddq_f32(accu8.val[1], accu8.val[1])};
+
+        float32x4x2_t sum2 = {
                 vpaddq_f32(sum.val[0], sum.val[0]),
-                vpaddq_f32(sum.val[1], sum.val[1]));
-        float32x4x2_t sum2 = vuzpq_f32(sum2_tmp.val[0], sum2_tmp.val[1]);
+                vpaddq_f32(sum.val[1], sum.val[1])};
         return vgetq_lane_f32(sum2.val[0], 0) + vgetq_lane_f32(sum2.val[1], 0);
     }
 };
@@ -1086,7 +1511,55 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
     }
 };

-#ifdef USE_F16C
+#if defined(USE_AVX512_F16C)
+
+template <class Quantizer, class Similarity>
+struct DCTemplate<Quantizer, Similarity, 16>
+        : SQDistanceComputer { // Update to handle 16 lanes
+    using Sim = Similarity;
+
+    Quantizer quant;
+
+    DCTemplate(size_t d, const std::vector<float>& trained)
+            : quant(d, trained) {}
+
+    float compute_distance(const float* x, const uint8_t* code) const {
+        Similarity sim(x);
+        sim.begin_16();
+        for (size_t i = 0; i < quant.d; i += 16) {
+            __m512 xi = quant.reconstruct_16_components(code, i);
+            sim.add_16_components(xi);
+        }
+        return sim.result_16();
+    }
+
+    float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+            const {
+        Similarity sim(nullptr);
+        sim.begin_16();
+        for (size_t i = 0; i < quant.d; i += 16) {
+            __m512 x1 = quant.reconstruct_16_components(code1, i);
+            __m512 x2 = quant.reconstruct_16_components(code2, i);
+            sim.add_16_components_2(x1, x2);
+        }
+        return sim.result_16();
+    }
+
+    void set_query(const float* x) final {
+        q = x;
+    }
+
+    float symmetric_dis(idx_t i, idx_t j) override {
+        return compute_code_distance(
+                codes + i * code_size, codes + j * code_size);
+    }
+
+    float query_to_code(const uint8_t* code) const final {
+        return compute_distance(q, code);
+    }
+};
+
+#elif defined(USE_F16C)

 template <class Quantizer, class Similarity>
 struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -1135,7 +1608,7 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <class Quantizer, class Similarity>
 struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
@@ -1233,7 +1706,60 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
     }
 };

-#ifdef __AVX2__
+#if defined(__AVX512F__)
+
+template <class Similarity>
+struct DistanceComputerByte<Similarity, 16> : SQDistanceComputer {
+    using Sim = Similarity;
+
+    int d;
+    std::vector<uint8_t> tmp;
+
+    DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
+
+    int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
+            const {
+        __m512i accu = _mm512_setzero_si512();
+        for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time
+            __m512i c1 = _mm512_cvtepu8_epi16(
+                    _mm256_loadu_si256((__m256i*)(code1 + i)));
+            __m512i c2 = _mm512_cvtepu8_epi16(
+                    _mm256_loadu_si256((__m256i*)(code2 + i)));
+            __m512i prod32;
+            if (Sim::metric_type == METRIC_INNER_PRODUCT) {
+                prod32 = _mm512_madd_epi16(c1, c2);
+            } else {
+                __m512i diff = _mm512_sub_epi16(c1, c2);
+                prod32 = _mm512_madd_epi16(diff, diff);
+            }
+            accu = _mm512_add_epi32(accu, prod32);
+        }
+        // Horizontally add elements of accu
+        return _mm512_reduce_add_epi32(accu);
+    }
+
+    void set_query(const float* x) final {
+        for (int i = 0; i < d; i++) {
+            tmp[i] = int(x[i]);
+        }
+    }
+
+    int compute_distance(const float* x, const uint8_t* code) {
+        set_query(x);
+        return compute_code_distance(tmp.data(), code);
+    }
+
+    float symmetric_dis(idx_t i, idx_t j) override {
+        return compute_code_distance(
+                codes + i * code_size, codes + j * code_size);
+    }
+
+    float query_to_code(const uint8_t* code) const final {
+        return compute_code_distance(tmp.data(), code);
+    }
+};
+
+#elif defined(__AVX2__)

 template <class Similarity>
 struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
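
Note that the AVX-512 byte distance computer consumes 32 codes per loop iteration (a 256-bit load widened to 16-bit lanes in a 512-bit register), which is why the dispatch code further down selects it only when d % 32 == 0, versus d % 16 == 0 for the AVX2 version.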
@@ -1298,7 +1824,7 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {

 #endif

-#ifdef __aarch64__
+#ifdef USE_NEON

 template <class Similarity>
 struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
@@ -1360,31 +1886,46 @@ SQDistanceComputer* select_distance_computer(
     switch (qtype) {
         case ScalarQuantizer::QT_8bit_uniform:
             return new DCTemplate<
-                    QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
+                    QuantizerTemplate<
+                            Codec8bit,
+                            QuantizerTemplateScaling::UNIFORM,
+                            SIMDWIDTH>,
                     Sim,
                     SIMDWIDTH>(d, trained);

         case ScalarQuantizer::QT_4bit_uniform:
             return new DCTemplate<
-                    QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
+                    QuantizerTemplate<
+                            Codec4bit,
+                            QuantizerTemplateScaling::UNIFORM,
+                            SIMDWIDTH>,
                     Sim,
                     SIMDWIDTH>(d, trained);

         case ScalarQuantizer::QT_8bit:
             return new DCTemplate<
-                    QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
+                    QuantizerTemplate<
+                            Codec8bit,
+                            QuantizerTemplateScaling::NON_UNIFORM,
+                            SIMDWIDTH>,
                     Sim,
                     SIMDWIDTH>(d, trained);

         case ScalarQuantizer::QT_6bit:
             return new DCTemplate<
-                    QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
+                    QuantizerTemplate<
+                            Codec6bit,
+                            QuantizerTemplateScaling::NON_UNIFORM,
+                            SIMDWIDTH>,
                     Sim,
                     SIMDWIDTH>(d, trained);

         case ScalarQuantizer::QT_4bit:
             return new DCTemplate<
-                    QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
+                    QuantizerTemplate<
+                            Codec4bit,
+                            QuantizerTemplateScaling::NON_UNIFORM,
+                            SIMDWIDTH>,
                     Sim,
                     SIMDWIDTH>(d, trained);

@@ -1392,15 +1933,31 @@ SQDistanceComputer* select_distance_computer(
             return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
                     d, trained);

+        case ScalarQuantizer::QT_bf16:
+            return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
+                    d, trained);
+
         case ScalarQuantizer::QT_8bit_direct:
+#if defined(__AVX512F__)
+            if (d % 32 == 0) {
+                return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
+            } else
+#elif defined(__AVX2__)
             if (d % 16 == 0) {
                 return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
-            } else {
+            } else
+#endif
+            {
                 return new DCTemplate<
                         Quantizer8bitDirect<SIMDWIDTH>,
                         Sim,
                         SIMDWIDTH>(d, trained);
             }
+        case ScalarQuantizer::QT_8bit_direct_signed:
+            return new DCTemplate<
+                    Quantizer8bitDirectSigned<SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
     }
     FAISS_THROW_MSG("unknown qtype");
     return nullptr;
@@ -1424,6 +1981,7 @@ void ScalarQuantizer::set_derived_sizes() {
         case QT_8bit:
         case QT_8bit_uniform:
         case QT_8bit_direct:
+        case QT_8bit_direct_signed:
            code_size = d;
            bits = 8;
            break;
@@ -1440,6 +1998,10 @@ void ScalarQuantizer::set_derived_sizes() {
            code_size = d * 2;
            bits = 16;
            break;
+        case QT_bf16:
+            code_size = d * 2;
+            bits = 16;
+            break;
    }
 }

@@ -1476,13 +2038,19 @@ void ScalarQuantizer::train(size_t n, const float* x) {
            break;
        case QT_fp16:
        case QT_8bit_direct:
+        case QT_bf16:
+        case QT_8bit_direct_signed:
            // no training necessary
            break;
    }
 }

 ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
-#if defined(USE_F16C) || defined(__aarch64__)
+#if defined(USE_AVX512_F16C)
+    if (d % 16 == 0) {
+        return select_quantizer_1<16>(qtype, d, trained);
+    } else
+#elif defined(USE_F16C) || defined(USE_NEON)
    if (d % 8 == 0) {
        return select_quantizer_1<8>(qtype, d, trained);
    } else
@@ -1513,7 +2081,17 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
 SQDistanceComputer* ScalarQuantizer::get_distance_computer(
         MetricType metric) const {
     FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
-#if defined(USE_F16C) || defined(__aarch64__)
+#if defined(USE_AVX512_F16C)
+    if (d % 16 == 0) {
+        if (metric == METRIC_L2) {
+            return select_distance_computer<SimilarityL2<16>>(
+                    qtype, d, trained);
+        } else {
+            return select_distance_computer<SimilarityIP<16>>(
+                    qtype, d, trained);
+        }
+    } else
+#elif defined(USE_F16C) || defined(USE_NEON)
     if (d % 8 == 0) {
         if (metric == METRIC_L2) {
             return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
@@ -1762,7 +2340,7 @@ InvertedListScanner* sel2_InvertedListScanner(
     }
 }

-template <class Similarity, class Codec, bool uniform>
+template <class Similarity, class Codec, QuantizerTemplateScaling SCALING>
 InvertedListScanner* sel12_InvertedListScanner(
         const ScalarQuantizer* sq,
         const Index* quantizer,
@@ -1770,7 +2348,7 @@ InvertedListScanner* sel12_InvertedListScanner(
         const IDSelector* sel,
         bool r) {
     constexpr int SIMDWIDTH = Similarity::simdwidth;
-    using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
+    using QuantizerClass = QuantizerTemplate<Codec, SCALING, SIMDWIDTH>;
     using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
     return sel2_InvertedListScanner<DCClass>(
             sq, quantizer, store_pairs, sel, r);
@@ -1786,36 +2364,70 @@ InvertedListScanner* sel1_InvertedListScanner(
     constexpr int SIMDWIDTH = Similarity::simdwidth;
     switch (sq->qtype) {
         case ScalarQuantizer::QT_8bit_uniform:
-            return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
+            return sel12_InvertedListScanner<
+                    Similarity,
+                    Codec8bit,
+                    QuantizerTemplateScaling::UNIFORM>(
                     sq, quantizer, store_pairs, sel, r);
         case ScalarQuantizer::QT_4bit_uniform:
-            return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
+            return sel12_InvertedListScanner<
+                    Similarity,
+                    Codec4bit,
+                    QuantizerTemplateScaling::UNIFORM>(
                     sq, quantizer, store_pairs, sel, r);
         case ScalarQuantizer::QT_8bit:
-            return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
+            return sel12_InvertedListScanner<
+                    Similarity,
+                    Codec8bit,
+                    QuantizerTemplateScaling::NON_UNIFORM>(
                     sq, quantizer, store_pairs, sel, r);
         case ScalarQuantizer::QT_4bit:
-            return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
+            return sel12_InvertedListScanner<
+                    Similarity,
+                    Codec4bit,
+                    QuantizerTemplateScaling::NON_UNIFORM>(
                     sq, quantizer, store_pairs, sel, r);
         case ScalarQuantizer::QT_6bit:
-            return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
+            return sel12_InvertedListScanner<
+                    Similarity,
+                    Codec6bit,
+                    QuantizerTemplateScaling::NON_UNIFORM>(
                     sq, quantizer, store_pairs, sel, r);
         case ScalarQuantizer::QT_fp16:
             return sel2_InvertedListScanner<DCTemplate<
                     QuantizerFP16<SIMDWIDTH>,
                     Similarity,
                     SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
+        case ScalarQuantizer::QT_bf16:
+            return sel2_InvertedListScanner<DCTemplate<
+                    QuantizerBF16<SIMDWIDTH>,
+                    Similarity,
+                    SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
         case ScalarQuantizer::QT_8bit_direct:
+#if defined(__AVX512F__)
+            if (sq->d % 32 == 0) {
+                return sel2_InvertedListScanner<
+                        DistanceComputerByte<Similarity, SIMDWIDTH>>(
+                        sq, quantizer, store_pairs, sel, r);
+            } else
+#elif defined(__AVX2__)
             if (sq->d % 16 == 0) {
                 return sel2_InvertedListScanner<
                         DistanceComputerByte<Similarity, SIMDWIDTH>>(
                         sq, quantizer, store_pairs, sel, r);
-            } else {
+            } else
+#endif
+            {
                 return sel2_InvertedListScanner<DCTemplate<
                         Quantizer8bitDirect<SIMDWIDTH>,
                         Similarity,
                         SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
             }
+        case ScalarQuantizer::QT_8bit_direct_signed:
+            return sel2_InvertedListScanner<DCTemplate<
+                    Quantizer8bitDirectSigned<SIMDWIDTH>,
+                    Similarity,
+                    SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
     }

     FAISS_THROW_MSG("unknown qtype");
@@ -1849,7 +2461,12 @@ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
         bool store_pairs,
         const IDSelector* sel,
         bool by_residual) const {
-#if defined(USE_F16C) || defined(__aarch64__)
+#if defined(USE_AVX512_F16C)
+    if (d % 16 == 0) {
+        return sel0_InvertedListScanner<16>(
+                mt, this, quantizer, store_pairs, sel, by_residual);
+    } else
+#elif defined(USE_F16C) || defined(USE_NEON)
     if (d % 8 == 0) {
         return sel0_InvertedListScanner<8>(
                 mt, this, quantizer, store_pairs, sel, by_residual);