faiss 0.5.2 → 0.6.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (169)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/ext/faiss/ext.cpp +1 -1
  5. data/ext/faiss/extconf.rb +5 -6
  6. data/ext/faiss/index_binary.cpp +76 -17
  7. data/ext/faiss/{index.cpp → index_rb.cpp} +108 -35
  8. data/ext/faiss/kmeans.cpp +12 -9
  9. data/ext/faiss/numo.hpp +11 -9
  10. data/ext/faiss/pca_matrix.cpp +10 -8
  11. data/ext/faiss/product_quantizer.cpp +14 -12
  12. data/ext/faiss/{utils.cpp → utils_rb.cpp} +10 -3
  13. data/ext/faiss/{utils.h → utils_rb.h} +6 -0
  14. data/lib/faiss/version.rb +1 -1
  15. data/lib/faiss.rb +1 -1
  16. data/vendor/faiss/faiss/AutoTune.cpp +130 -11
  17. data/vendor/faiss/faiss/AutoTune.h +14 -1
  18. data/vendor/faiss/faiss/Clustering.cpp +59 -10
  19. data/vendor/faiss/faiss/Clustering.h +12 -0
  20. data/vendor/faiss/faiss/IVFlib.cpp +31 -28
  21. data/vendor/faiss/faiss/Index.cpp +20 -8
  22. data/vendor/faiss/faiss/Index.h +25 -3
  23. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +19 -24
  24. data/vendor/faiss/faiss/IndexBinary.cpp +1 -0
  25. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +9 -4
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +45 -11
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +35 -22
  28. data/vendor/faiss/faiss/IndexFastScan.h +10 -1
  29. data/vendor/faiss/faiss/IndexFlat.cpp +193 -136
  30. data/vendor/faiss/faiss/IndexFlat.h +16 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +46 -22
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +24 -50
  34. data/vendor/faiss/faiss/IndexHNSW.h +14 -12
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +1 -1
  36. data/vendor/faiss/faiss/IndexIVF.cpp +76 -49
  37. data/vendor/faiss/faiss/IndexIVF.h +14 -4
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +11 -8
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +2 -2
  40. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +25 -14
  41. data/vendor/faiss/faiss/IndexIVFFastScan.h +26 -22
  42. data/vendor/faiss/faiss/IndexIVFFlat.cpp +10 -61
  43. data/vendor/faiss/faiss/IndexIVFFlatPanorama.cpp +39 -111
  44. data/vendor/faiss/faiss/IndexIVFPQ.cpp +89 -147
  45. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +37 -5
  46. data/vendor/faiss/faiss/IndexIVFPQR.cpp +2 -1
  47. data/vendor/faiss/faiss/IndexIVFRaBitQ.cpp +42 -30
  48. data/vendor/faiss/faiss/IndexIVFRaBitQ.h +2 -2
  49. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.cpp +246 -97
  50. data/vendor/faiss/faiss/IndexIVFRaBitQFastScan.h +32 -29
  51. data/vendor/faiss/faiss/IndexLSH.cpp +8 -6
  52. data/vendor/faiss/faiss/IndexLattice.cpp +29 -24
  53. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -0
  54. data/vendor/faiss/faiss/IndexNSG.cpp +2 -1
  55. data/vendor/faiss/faiss/IndexNSG.h +0 -2
  56. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +1 -1
  57. data/vendor/faiss/faiss/IndexPQ.cpp +19 -10
  58. data/vendor/faiss/faiss/IndexRaBitQ.cpp +26 -13
  59. data/vendor/faiss/faiss/IndexRaBitQ.h +2 -2
  60. data/vendor/faiss/faiss/IndexRaBitQFastScan.cpp +132 -78
  61. data/vendor/faiss/faiss/IndexRaBitQFastScan.h +14 -12
  62. data/vendor/faiss/faiss/IndexRefine.cpp +0 -30
  63. data/vendor/faiss/faiss/IndexShards.cpp +3 -4
  64. data/vendor/faiss/faiss/MetricType.h +16 -0
  65. data/vendor/faiss/faiss/VectorTransform.cpp +120 -0
  66. data/vendor/faiss/faiss/VectorTransform.h +23 -0
  67. data/vendor/faiss/faiss/clone_index.cpp +7 -4
  68. data/vendor/faiss/faiss/{cppcontrib/factory_tools.cpp → factory_tools.cpp} +1 -1
  69. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +1 -1
  70. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +37 -11
  71. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -28
  72. data/vendor/faiss/faiss/impl/ClusteringInitialization.cpp +367 -0
  73. data/vendor/faiss/faiss/impl/ClusteringInitialization.h +107 -0
  74. data/vendor/faiss/faiss/impl/CodePacker.cpp +4 -0
  75. data/vendor/faiss/faiss/impl/CodePacker.h +11 -3
  76. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.cpp +83 -0
  77. data/vendor/faiss/faiss/impl/CodePackerRaBitQ.h +47 -0
  78. data/vendor/faiss/faiss/impl/FaissAssert.h +60 -2
  79. data/vendor/faiss/faiss/impl/HNSW.cpp +25 -34
  80. data/vendor/faiss/faiss/impl/HNSW.h +8 -6
  81. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +34 -27
  82. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -1
  83. data/vendor/faiss/faiss/impl/NSG.cpp +6 -5
  84. data/vendor/faiss/faiss/impl/NSG.h +17 -7
  85. data/vendor/faiss/faiss/impl/Panorama.cpp +53 -46
  86. data/vendor/faiss/faiss/impl/Panorama.h +22 -6
  87. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +16 -5
  88. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +70 -58
  89. data/vendor/faiss/faiss/impl/RaBitQUtils.cpp +92 -0
  90. data/vendor/faiss/faiss/impl/RaBitQUtils.h +93 -31
  91. data/vendor/faiss/faiss/impl/RaBitQuantizer.cpp +12 -28
  92. data/vendor/faiss/faiss/impl/RaBitQuantizer.h +3 -10
  93. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.cpp +15 -41
  94. data/vendor/faiss/faiss/impl/RaBitQuantizerMultiBit.h +0 -4
  95. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +14 -9
  96. data/vendor/faiss/faiss/impl/ResultHandler.h +131 -50
  97. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +67 -2358
  98. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -2
  99. data/vendor/faiss/faiss/impl/VisitedTable.cpp +42 -0
  100. data/vendor/faiss/faiss/impl/VisitedTable.h +69 -0
  101. data/vendor/faiss/faiss/impl/expanded_scanners.h +158 -0
  102. data/vendor/faiss/faiss/impl/index_read.cpp +829 -471
  103. data/vendor/faiss/faiss/impl/index_read_utils.h +0 -1
  104. data/vendor/faiss/faiss/impl/index_write.cpp +17 -8
  105. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +47 -20
  106. data/vendor/faiss/faiss/impl/mapped_io.cpp +9 -2
  107. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +7 -2
  108. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +11 -3
  109. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +19 -13
  110. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +29 -21
  111. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx2.h → pq_code_distance/pq_code_distance-avx2.cpp} +42 -215
  112. data/vendor/faiss/faiss/impl/{code_distance/code_distance-avx512.h → pq_code_distance/pq_code_distance-avx512.cpp} +68 -107
  113. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-generic.cpp +141 -0
  114. data/vendor/faiss/faiss/impl/pq_code_distance/pq_code_distance-inl.h +23 -0
  115. data/vendor/faiss/faiss/impl/{code_distance/code_distance-sve.h → pq_code_distance/pq_code_distance-sve.cpp} +57 -144
  116. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +9 -6
  117. data/vendor/faiss/faiss/impl/scalar_quantizer/codecs.h +121 -0
  118. data/vendor/faiss/faiss/impl/scalar_quantizer/distance_computers.h +136 -0
  119. data/vendor/faiss/faiss/impl/scalar_quantizer/quantizers.h +280 -0
  120. data/vendor/faiss/faiss/impl/scalar_quantizer/scanners.h +164 -0
  121. data/vendor/faiss/faiss/impl/scalar_quantizer/similarities.h +94 -0
  122. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx2.cpp +455 -0
  123. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-avx512.cpp +430 -0
  124. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-dispatch.h +329 -0
  125. data/vendor/faiss/faiss/impl/scalar_quantizer/sq-neon.cpp +467 -0
  126. data/vendor/faiss/faiss/impl/scalar_quantizer/training.cpp +203 -0
  127. data/vendor/faiss/faiss/impl/scalar_quantizer/training.h +42 -0
  128. data/vendor/faiss/faiss/impl/simd_dispatch.h +139 -0
  129. data/vendor/faiss/faiss/impl/simd_result_handlers.h +18 -18
  130. data/vendor/faiss/faiss/index_factory.cpp +35 -16
  131. data/vendor/faiss/faiss/index_io.h +29 -3
  132. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +7 -4
  133. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +1 -1
  134. data/vendor/faiss/faiss/svs/IndexSVSFaissUtils.h +9 -19
  135. data/vendor/faiss/faiss/svs/IndexSVSFlat.h +2 -0
  136. data/vendor/faiss/faiss/svs/IndexSVSVamana.h +2 -1
  137. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.cpp +9 -1
  138. data/vendor/faiss/faiss/svs/IndexSVSVamanaLeanVec.h +9 -0
  139. data/vendor/faiss/faiss/utils/Heap.cpp +46 -0
  140. data/vendor/faiss/faiss/utils/Heap.h +21 -0
  141. data/vendor/faiss/faiss/utils/NeuralNet.cpp +10 -7
  142. data/vendor/faiss/faiss/utils/distances.cpp +141 -23
  143. data/vendor/faiss/faiss/utils/distances.h +98 -0
  144. data/vendor/faiss/faiss/utils/distances_dispatch.h +170 -0
  145. data/vendor/faiss/faiss/utils/distances_simd.cpp +74 -3511
  146. data/vendor/faiss/faiss/utils/extra_distances-inl.h +164 -157
  147. data/vendor/faiss/faiss/utils/extra_distances.cpp +52 -95
  148. data/vendor/faiss/faiss/utils/extra_distances.h +47 -1
  149. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +0 -1
  150. data/vendor/faiss/faiss/utils/partitioning.cpp +1 -1
  151. data/vendor/faiss/faiss/utils/pq_code_distance.h +251 -0
  152. data/vendor/faiss/faiss/utils/rabitq_simd.h +260 -0
  153. data/vendor/faiss/faiss/utils/simd_impl/distances_aarch64.cpp +150 -0
  154. data/vendor/faiss/faiss/utils/simd_impl/distances_arm_sve.cpp +568 -0
  155. data/vendor/faiss/faiss/utils/simd_impl/distances_autovec-inl.h +153 -0
  156. data/vendor/faiss/faiss/utils/simd_impl/distances_avx2.cpp +1185 -0
  157. data/vendor/faiss/faiss/utils/simd_impl/distances_avx512.cpp +1092 -0
  158. data/vendor/faiss/faiss/utils/simd_impl/distances_sse-inl.h +391 -0
  159. data/vendor/faiss/faiss/utils/simd_levels.cpp +322 -0
  160. data/vendor/faiss/faiss/utils/simd_levels.h +91 -0
  161. data/vendor/faiss/faiss/utils/simdlib_avx2.h +12 -1
  162. data/vendor/faiss/faiss/utils/simdlib_avx512.h +69 -0
  163. data/vendor/faiss/faiss/utils/simdlib_neon.h +6 -0
  164. data/vendor/faiss/faiss/utils/sorting.cpp +4 -4
  165. data/vendor/faiss/faiss/utils/utils.cpp +16 -9
  166. metadata +47 -18
  167. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +0 -81
  168. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +0 -186
  169. data/vendor/faiss/faiss/{cppcontrib/factory_tools.h → factory_tools.h} +0 -0
@@ -7,1971 +7,23 @@
 
 // -*- c++ -*-
 
-#include <faiss/impl/ScalarQuantizer.h>
-
-#include <algorithm>
-#include <cstdio>
-
-#include <faiss/impl/platform_macros.h>
-
-#ifdef __SSE__
-#include <immintrin.h>
-#endif
-
-#include <faiss/IndexIVF.h>
-#include <faiss/impl/AuxIndexStructures.h>
-#include <faiss/impl/FaissAssert.h>
-#include <faiss/impl/IDSelector.h>
-#include <faiss/utils/bf16.h>
-#include <faiss/utils/fp16.h>
-#include <faiss/utils/utils.h>
-
-namespace faiss {
-
-/*******************************************************************
- * ScalarQuantizer implementation
- *
- * The main source of complexity is to support combinations of 4
- * variants without incurring runtime tests or virtual function calls:
- *
- * - 4 / 8 bits per code component
- * - uniform / non-uniform
- * - IP / L2 distance search
- * - scalar / AVX distance computation
- *
- * The appropriate Quantizer object is returned via select_quantizer
- * that hides the template mess.
- ********************************************************************/
-
-#if defined(__AVX512F__) && defined(__F16C__)
-#define USE_AVX512_F16C
-#elif defined(__AVX2__)
-#ifdef __F16C__
-#define USE_F16C
-#else
-#warning \
-        "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
-#endif
-#endif
-
-#if defined(__aarch64__)
-#if defined(__GNUC__) && __GNUC__ < 8
-#warning \
-        "Cannot enable NEON optimizations in scalar quantizer if the compiler is GCC<8"
-#else
-#define USE_NEON
-#endif
-#endif
-
[removed lines 66-1467 elided: the file's anonymous namespace, containing the Codec8bit/Codec4bit/Codec6bit codecs, the uniform and non-uniform QuantizerTemplate variants, the QuantizerFP16/QuantizerBF16/Quantizer8bitDirect/Quantizer8bitDirectSigned quantizers, select_quantizer_1, the train_Uniform/train_NonUniform range estimators, and the SimilarityL2/SimilarityIP accumulators, each with scalar, AVX-512, AVX2, and NEON specializations, all deleted by this hunk]
-
-/*******************************************************************
- * DistanceComputer: combines a similarity and a quantizer to do
- * code-to-vector or code-to-code comparisons
- *******************************************************************/
-
-template <class Quantizer, class Similarity, int SIMDWIDTH>
-struct DCTemplate : SQDistanceComputer {};
-
-template <class Quantizer, class Similarity>
-struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
-    using Sim = Similarity;
+#include <cstring>
+#include <memory>
 
-    Quantizer quant;
-
-    DCTemplate(size_t d, const std::vector<float>& trained)
-            : quant(d, trained) {}
-
-    float compute_distance(const float* x, const uint8_t* code) const {
-        Similarity sim(x);
-        sim.begin();
-        for (size_t i = 0; i < quant.d; i++) {
-            float xi = quant.reconstruct_component(code, i);
-            sim.add_component(xi);
-        }
-        return sim.result();
-    }
-
-    float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
-            const {
-        Similarity sim(nullptr);
-        sim.begin();
-        for (size_t i = 0; i < quant.d; i++) {
-            float x1 = quant.reconstruct_component(code1, i);
-            float x2 = quant.reconstruct_component(code2, i);
-            sim.add_component_2(x1, x2);
-        }
-        return sim.result();
-    }
-
-    void set_query(const float* x) final {
-        q = x;
-    }
-
-    float symmetric_dis(idx_t i, idx_t j) override {
-        return compute_code_distance(
-                codes + i * code_size, codes + j * code_size);
-    }
-
-    float query_to_code(const uint8_t* code) const final {
-        return compute_distance(q, code);
-    }
-};
-
-#if defined(USE_AVX512_F16C)
-
-template <class Quantizer, class Similarity>
-struct DCTemplate<Quantizer, Similarity, 16>
-        : SQDistanceComputer { // Update to handle 16 lanes
-    using Sim = Similarity;
-
-    Quantizer quant;
-
-    DCTemplate(size_t d, const std::vector<float>& trained)
-            : quant(d, trained) {}
-
-    float compute_distance(const float* x, const uint8_t* code) const {
-        Similarity sim(x);
-        sim.begin_16();
-        for (size_t i = 0; i < quant.d; i += 16) {
-            __m512 xi = quant.reconstruct_16_components(code, i);
-            sim.add_16_components(xi);
-        }
-        return sim.result_16();
-    }
-
-    float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
-            const {
-        Similarity sim(nullptr);
-        sim.begin_16();
-        for (size_t i = 0; i < quant.d; i += 16) {
-            __m512 x1 = quant.reconstruct_16_components(code1, i);
-            __m512 x2 = quant.reconstruct_16_components(code2, i);
-            sim.add_16_components_2(x1, x2);
-        }
-        return sim.result_16();
-    }
-
-    void set_query(const float* x) final {
-        q = x;
-    }
-
-    float symmetric_dis(idx_t i, idx_t j) override {
-        return compute_code_distance(
-                codes + i * code_size, codes + j * code_size);
-    }
-
-    float query_to_code(const uint8_t* code) const final {
-        return compute_distance(q, code);
-    }
-};
-
-#elif defined(USE_F16C)
-
-template <class Quantizer, class Similarity>
-struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
-    using Sim = Similarity;
-
-    Quantizer quant;
-
-    DCTemplate(size_t d, const std::vector<float>& trained)
-            : quant(d, trained) {}
-
-    float compute_distance(const float* x, const uint8_t* code) const {
-        Similarity sim(x);
-        sim.begin_8();
-        for (size_t i = 0; i < quant.d; i += 8) {
-            __m256 xi = quant.reconstruct_8_components(code, i);
-            sim.add_8_components(xi);
-        }
-        return sim.result_8();
-    }
-
-    float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
-            const {
-        Similarity sim(nullptr);
-        sim.begin_8();
-        for (size_t i = 0; i < quant.d; i += 8) {
1596
- __m256 x1 = quant.reconstruct_8_components(code1, i);
1597
- __m256 x2 = quant.reconstruct_8_components(code2, i);
1598
- sim.add_8_components_2(x1, x2);
1599
- }
1600
- return sim.result_8();
1601
- }
1602
-
1603
- void set_query(const float* x) final {
1604
- q = x;
1605
- }
1606
-
1607
- float symmetric_dis(idx_t i, idx_t j) override {
1608
- return compute_code_distance(
1609
- codes + i * code_size, codes + j * code_size);
1610
- }
1611
-
1612
- float query_to_code(const uint8_t* code) const final {
1613
- return compute_distance(q, code);
1614
- }
1615
- };
1616
-
1617
- #endif
1618
-
1619
- #ifdef USE_NEON
1620
-
1621
- template <class Quantizer, class Similarity>
1622
- struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
1623
- using Sim = Similarity;
1624
-
1625
- Quantizer quant;
1626
-
1627
- DCTemplate(size_t d, const std::vector<float>& trained)
1628
- : quant(d, trained) {}
1629
- float compute_distance(const float* x, const uint8_t* code) const {
1630
- Similarity sim(x);
1631
- sim.begin_8();
1632
- for (size_t i = 0; i < quant.d; i += 8) {
1633
- float32x4x2_t xi = quant.reconstruct_8_components(code, i);
1634
- sim.add_8_components(xi);
1635
- }
1636
- return sim.result_8();
1637
- }
1638
-
1639
- float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1640
- const {
1641
- Similarity sim(nullptr);
1642
- sim.begin_8();
1643
- for (size_t i = 0; i < quant.d; i += 8) {
1644
- float32x4x2_t x1 = quant.reconstruct_8_components(code1, i);
1645
- float32x4x2_t x2 = quant.reconstruct_8_components(code2, i);
1646
- sim.add_8_components_2(x1, x2);
1647
- }
1648
- return sim.result_8();
1649
- }
1650
-
1651
- void set_query(const float* x) final {
1652
- q = x;
1653
- }
1654
-
1655
- float symmetric_dis(idx_t i, idx_t j) override {
1656
- return compute_code_distance(
1657
- codes + i * code_size, codes + j * code_size);
1658
- }
1659
-
1660
- float query_to_code(const uint8_t* code) const final {
1661
- return compute_distance(q, code);
1662
- }
1663
- };
1664
- #endif
1665
-
1666
- /*******************************************************************
1667
- * DistanceComputerByte: computes distances in the integer domain
1668
- *******************************************************************/
1669
-
1670
- template <class Similarity, int SIMDWIDTH>
1671
- struct DistanceComputerByte : SQDistanceComputer {};
1672
-
1673
- template <class Similarity>
1674
- struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1675
- using Sim = Similarity;
1676
-
1677
- int d;
1678
- std::vector<uint8_t> tmp;
1679
-
1680
- DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1681
-
1682
- int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1683
- const {
1684
- int accu = 0;
1685
- for (int i = 0; i < d; i++) {
1686
- if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1687
- accu += int(code1[i]) * code2[i];
1688
- } else {
1689
- int diff = int(code1[i]) - code2[i];
1690
- accu += diff * diff;
1691
- }
1692
- }
1693
- return accu;
1694
- }
1695
-
1696
- void set_query(const float* x) final {
1697
- for (int i = 0; i < d; i++) {
1698
- tmp[i] = int(x[i]);
1699
- }
1700
- }
1701
-
1702
- int compute_distance(const float* x, const uint8_t* code) {
1703
- set_query(x);
1704
- return compute_code_distance(tmp.data(), code);
1705
- }
1706
-
1707
- float symmetric_dis(idx_t i, idx_t j) override {
1708
- return compute_code_distance(
1709
- codes + i * code_size, codes + j * code_size);
1710
- }
1711
-
1712
- float query_to_code(const uint8_t* code) const final {
1713
- return compute_code_distance(tmp.data(), code);
1714
- }
1715
- };
1716
-
1717
- #if defined(__AVX512F__)
1718
-
1719
- template <class Similarity>
1720
- struct DistanceComputerByte<Similarity, 16> : SQDistanceComputer {
1721
- using Sim = Similarity;
1722
-
1723
- int d;
1724
- std::vector<uint8_t> tmp;
1725
-
1726
- DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1727
-
1728
- int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1729
- const {
1730
- __m512i accu = _mm512_setzero_si512();
1731
- for (int i = 0; i < d; i += 32) { // Process 32 bytes at a time
1732
- __m512i c1 = _mm512_cvtepu8_epi16(
1733
- _mm256_loadu_si256((__m256i*)(code1 + i)));
1734
- __m512i c2 = _mm512_cvtepu8_epi16(
1735
- _mm256_loadu_si256((__m256i*)(code2 + i)));
1736
- __m512i prod32;
1737
- if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1738
- prod32 = _mm512_madd_epi16(c1, c2);
1739
- } else {
1740
- __m512i diff = _mm512_sub_epi16(c1, c2);
1741
- prod32 = _mm512_madd_epi16(diff, diff);
1742
- }
1743
- accu = _mm512_add_epi32(accu, prod32);
1744
- }
1745
- // Horizontally add elements of accu
1746
- return _mm512_reduce_add_epi32(accu);
1747
- }
1748
-
1749
- void set_query(const float* x) final {
1750
- for (int i = 0; i < d; i++) {
1751
- tmp[i] = int(x[i]);
1752
- }
1753
- }
1754
-
1755
- int compute_distance(const float* x, const uint8_t* code) {
1756
- set_query(x);
1757
- return compute_code_distance(tmp.data(), code);
1758
- }
1759
-
1760
- float symmetric_dis(idx_t i, idx_t j) override {
1761
- return compute_code_distance(
1762
- codes + i * code_size, codes + j * code_size);
1763
- }
1764
-
1765
- float query_to_code(const uint8_t* code) const final {
1766
- return compute_code_distance(tmp.data(), code);
1767
- }
1768
- };
1769
-
1770
- #elif defined(__AVX2__)
1771
-
1772
- template <class Similarity>
1773
- struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1774
- using Sim = Similarity;
1775
-
1776
- int d;
1777
- std::vector<uint8_t> tmp;
1778
-
1779
- DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1780
-
1781
- int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1782
- const {
1783
- // __m256i accu = _mm256_setzero_ps ();
1784
- __m256i accu = _mm256_setzero_si256();
1785
- for (int i = 0; i < d; i += 16) {
1786
- // load 16 bytes, convert to 16 uint16_t
1787
- __m256i c1 = _mm256_cvtepu8_epi16(
1788
- _mm_loadu_si128((__m128i*)(code1 + i)));
1789
- __m256i c2 = _mm256_cvtepu8_epi16(
1790
- _mm_loadu_si128((__m128i*)(code2 + i)));
1791
- __m256i prod32;
1792
- if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1793
- prod32 = _mm256_madd_epi16(c1, c2);
1794
- } else {
1795
- __m256i diff = _mm256_sub_epi16(c1, c2);
1796
- prod32 = _mm256_madd_epi16(diff, diff);
1797
- }
1798
- accu = _mm256_add_epi32(accu, prod32);
1799
- }
1800
- __m128i sum = _mm256_extractf128_si256(accu, 0);
1801
- sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1));
1802
- sum = _mm_hadd_epi32(sum, sum);
1803
- sum = _mm_hadd_epi32(sum, sum);
1804
- return _mm_cvtsi128_si32(sum);
1805
- }
1806
-
1807
- void set_query(const float* x) final {
1808
- /*
1809
- for (int i = 0; i < d; i += 8) {
1810
- __m256 xi = _mm256_loadu_ps (x + i);
1811
- __m256i ci = _mm256_cvtps_epi32(xi);
1812
- */
1813
- for (int i = 0; i < d; i++) {
1814
- tmp[i] = int(x[i]);
1815
- }
1816
- }
1817
-
1818
- int compute_distance(const float* x, const uint8_t* code) {
1819
- set_query(x);
1820
- return compute_code_distance(tmp.data(), code);
1821
- }
1822
-
1823
- float symmetric_dis(idx_t i, idx_t j) override {
1824
- return compute_code_distance(
1825
- codes + i * code_size, codes + j * code_size);
1826
- }
1827
-
1828
- float query_to_code(const uint8_t* code) const final {
1829
- return compute_code_distance(tmp.data(), code);
1830
- }
1831
- };
1832
-
1833
- #endif
1834
-
1835
- #ifdef USE_NEON
1836
-
1837
- template <class Similarity>
1838
- struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1839
- using Sim = Similarity;
1840
-
1841
- int d;
1842
- std::vector<uint8_t> tmp;
1843
-
1844
- DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1845
-
1846
- int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1847
- const {
1848
- int accu = 0;
1849
- for (int i = 0; i < d; i++) {
1850
- if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1851
- accu += int(code1[i]) * code2[i];
1852
- } else {
1853
- int diff = int(code1[i]) - code2[i];
1854
- accu += diff * diff;
1855
- }
1856
- }
1857
- return accu;
1858
- }
1859
-
1860
- void set_query(const float* x) final {
1861
- for (int i = 0; i < d; i++) {
1862
- tmp[i] = int(x[i]);
1863
- }
1864
- }
1865
-
1866
- int compute_distance(const float* x, const uint8_t* code) {
1867
- set_query(x);
1868
- return compute_code_distance(tmp.data(), code);
1869
- }
1870
-
1871
- float symmetric_dis(idx_t i, idx_t j) override {
1872
- return compute_code_distance(
1873
- codes + i * code_size, codes + j * code_size);
1874
- }
1875
-
1876
- float query_to_code(const uint8_t* code) const final {
1877
- return compute_code_distance(tmp.data(), code);
1878
- }
1879
- };
1880
-
1881
- #endif
1882
-
1883
- /*******************************************************************
1884
- * select_distance_computer: runtime selection of template
1885
- * specialization
1886
- *******************************************************************/
1887
-
1888
- template <class Sim>
1889
- SQDistanceComputer* select_distance_computer(
1890
- QuantizerType qtype,
1891
- size_t d,
1892
- const std::vector<float>& trained) {
1893
- constexpr int SIMDWIDTH = Sim::simdwidth;
1894
- switch (qtype) {
1895
- case ScalarQuantizer::QT_8bit_uniform:
1896
- return new DCTemplate<
1897
- QuantizerTemplate<
1898
- Codec8bit,
1899
- QuantizerTemplateScaling::UNIFORM,
1900
- SIMDWIDTH>,
1901
- Sim,
1902
- SIMDWIDTH>(d, trained);
1903
-
1904
- case ScalarQuantizer::QT_4bit_uniform:
1905
- return new DCTemplate<
1906
- QuantizerTemplate<
1907
- Codec4bit,
1908
- QuantizerTemplateScaling::UNIFORM,
1909
- SIMDWIDTH>,
1910
- Sim,
1911
- SIMDWIDTH>(d, trained);
1912
-
1913
- case ScalarQuantizer::QT_8bit:
1914
- return new DCTemplate<
1915
- QuantizerTemplate<
1916
- Codec8bit,
1917
- QuantizerTemplateScaling::NON_UNIFORM,
1918
- SIMDWIDTH>,
1919
- Sim,
1920
- SIMDWIDTH>(d, trained);
1921
-
1922
- case ScalarQuantizer::QT_6bit:
1923
- return new DCTemplate<
1924
- QuantizerTemplate<
1925
- Codec6bit,
1926
- QuantizerTemplateScaling::NON_UNIFORM,
1927
- SIMDWIDTH>,
1928
- Sim,
1929
- SIMDWIDTH>(d, trained);
13
+ #include <faiss/impl/ScalarQuantizer.h>
14
+ #include <faiss/utils/simd_levels.h>
1930
15
 
1931
- case ScalarQuantizer::QT_4bit:
1932
- return new DCTemplate<
1933
- QuantizerTemplate<
1934
- Codec4bit,
1935
- QuantizerTemplateScaling::NON_UNIFORM,
1936
- SIMDWIDTH>,
1937
- Sim,
1938
- SIMDWIDTH>(d, trained);
16
+ #include <faiss/impl/scalar_quantizer/training.h>
1939
17
 
1940
- case ScalarQuantizer::QT_fp16:
1941
- return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1942
- d, trained);
18
+ #include <faiss/impl/FaissAssert.h>
19
+ #include <faiss/impl/simd_dispatch.h>
1943
20
 
1944
- case ScalarQuantizer::QT_bf16:
1945
- return new DCTemplate<QuantizerBF16<SIMDWIDTH>, Sim, SIMDWIDTH>(
1946
- d, trained);
21
+ #include <faiss/impl/scalar_quantizer/scanners.h>
1947
22
 
1948
- case ScalarQuantizer::QT_8bit_direct:
1949
- #if defined(__AVX512F__)
1950
- if (d % 32 == 0) {
1951
- return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1952
- } else
1953
- #elif defined(__AVX2__)
1954
- if (d % 16 == 0) {
1955
- return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1956
- } else
1957
- #endif
1958
- {
1959
- return new DCTemplate<
1960
- Quantizer8bitDirect<SIMDWIDTH>,
1961
- Sim,
1962
- SIMDWIDTH>(d, trained);
1963
- }
1964
- case ScalarQuantizer::QT_8bit_direct_signed:
1965
- return new DCTemplate<
1966
- Quantizer8bitDirectSigned<SIMDWIDTH>,
1967
- Sim,
1968
- SIMDWIDTH>(d, trained);
1969
- }
1970
- FAISS_THROW_MSG("unknown qtype");
1971
- return nullptr;
1972
- }
23
+ #define THE_LEVEL_TO_DISPATCH SIMDLevel::NONE
24
+ #include <faiss/impl/scalar_quantizer/sq-dispatch.h>
1973
25
 
1974
- } // anonymous namespace
26
+ namespace faiss {
1975
27
 
1976
28
  /*******************************************************************
1977
29
  * ScalarQuantizer implementation
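Note on the hunk above: the AVX512/AVX2/NEON specializations of SimilarityIP, the DCTemplate and DistanceComputerByte distance computers, and the compile-time select_distance_computer switch are deleted from ScalarQuantizer.cpp and replaced by the new scalar_quantizer/ headers (training.h, scanners.h) together with the THE_LEVEL_TO_DISPATCH / sq-dispatch.h pair, which suggests the same factory bodies are now stamped out once per SIMD level rather than chosen with #if blocks. A minimal single-file sketch of that stamping idea follows; every name in it (sq_factory, DEFINE_SQ_FACTORY, the AVX2 enumerator) is illustrative, not the real faiss API.

    #include <cstdio>

    enum class SIMDLevel { NONE, AVX2 };

    // Forward declaration of the per-level factory.
    template <SIMDLevel SL>
    const char* sq_factory();

    // The "generic body", written once and stamped per level. In the real
    // library this role appears to be played by a header included after
    // defining THE_LEVEL_TO_DISPATCH; a macro stands in for the include here.
    #define DEFINE_SQ_FACTORY(LEVEL, LABEL) \
        template <>                         \
        const char* sq_factory<LEVEL>() {   \
            return LABEL;                   \
        }

    DEFINE_SQ_FACTORY(SIMDLevel::NONE, "scalar code path")
    DEFINE_SQ_FACTORY(SIMDLevel::AVX2, "AVX2 code path")

    int main() {
        // Each stamped specialization can live in a translation unit built
        // with different compiler flags; callers pick one at runtime.
        std::printf("%s\n", sq_factory<SIMDLevel::AVX2>());
        return 0;
    }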
@@ -2010,10 +62,15 @@ void ScalarQuantizer::set_derived_sizes() {
2010
62
  code_size = d * 2;
2011
63
  bits = 16;
2012
64
  break;
65
+ default:
66
+ break;
2013
67
  }
2014
68
  }
2015
69
 
2016
70
  void ScalarQuantizer::train(size_t n, const float* x) {
71
+ using scalar_quantizer::train_NonUniform;
72
+ using scalar_quantizer::train_Uniform;
73
+
2017
74
  int bit_per_dim = qtype == QT_4bit_uniform ? 4
2018
75
  : qtype == QT_4bit ? 4
2019
76
  : qtype == QT_6bit ? 6
@@ -2039,7 +96,7 @@ void ScalarQuantizer::train(size_t n, const float* x) {
2039
96
  rangestat,
2040
97
  rangestat_arg,
2041
98
  n,
2042
- d,
99
+ int(d),
2043
100
  1 << bit_per_dim,
2044
101
  x,
2045
102
  trained);
@@ -2050,22 +107,23 @@ void ScalarQuantizer::train(size_t n, const float* x) {
2050
107
  case QT_8bit_direct_signed:
2051
108
  // no training necessary
2052
109
  break;
110
+ default:
111
+ break;
2053
112
  }
2054
113
  }
2055
114
 
2056
115
  ScalarQuantizer::SQuantizer* ScalarQuantizer::select_quantizer() const {
2057
- #if defined(USE_AVX512_F16C)
2058
- if (d % 16 == 0) {
2059
- return select_quantizer_1<16>(qtype, d, trained);
2060
- } else
2061
- #elif defined(USE_F16C) || defined(USE_NEON)
2062
- if (d % 8 == 0) {
2063
- return select_quantizer_1<8>(qtype, d, trained);
2064
- } else
2065
- #endif
2066
- {
2067
- return select_quantizer_1<1>(qtype, d, trained);
2068
- }
116
+ return with_simd_level([&]<SIMDLevel SL>() -> SQuantizer* {
117
+ if constexpr (SL != SIMDLevel::NONE) {
118
+ auto* q = scalar_quantizer::sq_select_quantizer<SL>(
119
+ qtype, d, trained);
120
+ if (q) {
121
+ return q;
122
+ }
123
+ }
124
+ return scalar_quantizer::sq_select_quantizer<SIMDLevel::NONE>(
125
+ qtype, d, trained);
126
+ });
2069
127
  }
2070
128
 
2071
129
  void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
@@ -2088,404 +146,55 @@ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
2088
146
  }
2089
147
  }
2090
148
 
2091
- SQDistanceComputer* ScalarQuantizer::get_distance_computer(
149
+ ScalarQuantizer::SQDistanceComputer* ScalarQuantizer::get_distance_computer(
2092
150
  MetricType metric) const {
2093
151
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
2094
- #if defined(USE_AVX512_F16C)
2095
- if (d % 16 == 0) {
2096
- if (metric == METRIC_L2) {
2097
- return select_distance_computer<SimilarityL2<16>>(
2098
- qtype, d, trained);
2099
- } else {
2100
- return select_distance_computer<SimilarityIP<16>>(
2101
- qtype, d, trained);
2102
- }
2103
- } else
2104
- #elif defined(USE_F16C) || defined(USE_NEON)
2105
- if (d % 8 == 0) {
2106
- if (metric == METRIC_L2) {
2107
- return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
2108
- } else {
2109
- return select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
2110
- }
2111
- } else
2112
- #endif
2113
- {
2114
- if (metric == METRIC_L2) {
2115
- return select_distance_computer<SimilarityL2<1>>(qtype, d, trained);
2116
- } else {
2117
- return select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
2118
- }
2119
- }
2120
- }
2121
-
2122
- /*******************************************************************
2123
- * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
2124
- *
2125
- * It is an InvertedListScanner, but is designed to work with
2126
- * IndexScalarQuantizer as well.
2127
- ********************************************************************/
2128
-
2129
- namespace {
2130
-
2131
- template <class DCClass, int use_sel>
2132
- struct IVFSQScannerIP : InvertedListScanner {
2133
- DCClass dc;
2134
- bool by_residual;
2135
-
2136
- float accu0; /// added to all distances
2137
-
2138
- IVFSQScannerIP(
2139
- int d,
2140
- const std::vector<float>& trained,
2141
- size_t code_size,
2142
- bool store_pairs,
2143
- const IDSelector* sel,
2144
- bool by_residual)
2145
- : dc(d, trained), by_residual(by_residual), accu0(0) {
2146
- this->store_pairs = store_pairs;
2147
- this->sel = sel;
2148
- this->code_size = code_size;
2149
- this->keep_max = true;
2150
- }
2151
-
2152
- void set_query(const float* query) override {
2153
- dc.set_query(query);
2154
- }
2155
-
2156
- void set_list(idx_t list_no, float coarse_dis) override {
2157
- this->list_no = list_no;
2158
- accu0 = by_residual ? coarse_dis : 0;
2159
- }
2160
-
2161
- float distance_to_code(const uint8_t* code) const final {
2162
- return accu0 + dc.query_to_code(code);
2163
- }
2164
-
2165
- size_t scan_codes(
2166
- size_t list_size,
2167
- const uint8_t* codes,
2168
- const idx_t* ids,
2169
- float* simi,
2170
- idx_t* idxi,
2171
- size_t k) const override {
2172
- size_t nup = 0;
2173
-
2174
- for (size_t j = 0; j < list_size; j++, codes += code_size) {
2175
- if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
2176
- continue;
2177
- }
2178
-
2179
- float accu = accu0 + dc.query_to_code(codes);
2180
-
2181
- if (accu > simi[0]) {
2182
- int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
2183
- minheap_replace_top(k, simi, idxi, accu, id);
2184
- nup++;
2185
- }
2186
- }
2187
- return nup;
2188
- }
2189
-
2190
- void scan_codes_range(
2191
- size_t list_size,
2192
- const uint8_t* codes,
2193
- const idx_t* ids,
2194
- float radius,
2195
- RangeQueryResult& res) const override {
2196
- for (size_t j = 0; j < list_size; j++, codes += code_size) {
2197
- if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
2198
- continue;
2199
- }
2200
-
2201
- float accu = accu0 + dc.query_to_code(codes);
2202
- if (accu > radius) {
2203
- int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
2204
- res.add(accu, id);
2205
- }
2206
- }
2207
- }
2208
- };
2209
-
2210
- /* use_sel = 0: don't check selector
2211
- * = 1: check on ids[j]
2212
- * = 2: check in j directly (normally ids is nullptr and store_pairs)
2213
- */
2214
- template <class DCClass, int use_sel>
2215
- struct IVFSQScannerL2 : InvertedListScanner {
2216
- DCClass dc;
2217
-
2218
- bool by_residual;
2219
- const Index* quantizer;
2220
- const float* x; /// current query
2221
-
2222
- std::vector<float> tmp;
2223
-
2224
- IVFSQScannerL2(
2225
- int d,
2226
- const std::vector<float>& trained,
2227
- size_t code_size,
2228
- const Index* quantizer,
2229
- bool store_pairs,
2230
- const IDSelector* sel,
2231
- bool by_residual)
2232
- : dc(d, trained),
2233
- by_residual(by_residual),
2234
- quantizer(quantizer),
2235
- x(nullptr),
2236
- tmp(d) {
2237
- this->store_pairs = store_pairs;
2238
- this->sel = sel;
2239
- this->code_size = code_size;
2240
- }
2241
-
2242
- void set_query(const float* query) override {
2243
- x = query;
2244
- if (!quantizer) {
2245
- dc.set_query(query);
2246
- }
2247
- }
2248
-
2249
- void set_list(idx_t list_no, float /*coarse_dis*/) override {
2250
- this->list_no = list_no;
2251
- if (by_residual) {
2252
- // shift of x_in wrt centroid
2253
- quantizer->compute_residual(x, tmp.data(), list_no);
2254
- dc.set_query(tmp.data());
2255
- } else {
2256
- dc.set_query(x);
2257
- }
2258
- }
2259
-
2260
- float distance_to_code(const uint8_t* code) const final {
2261
- return dc.query_to_code(code);
2262
- }
2263
-
2264
- size_t scan_codes(
2265
- size_t list_size,
2266
- const uint8_t* codes,
2267
- const idx_t* ids,
2268
- float* simi,
2269
- idx_t* idxi,
2270
- size_t k) const override {
2271
- size_t nup = 0;
2272
- for (size_t j = 0; j < list_size; j++, codes += code_size) {
2273
- if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
2274
- continue;
2275
- }
2276
-
2277
- float dis = dc.query_to_code(codes);
2278
-
2279
- if (dis < simi[0]) {
2280
- int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
2281
- maxheap_replace_top(k, simi, idxi, dis, id);
2282
- nup++;
152
+ return with_simd_level([&]<SIMDLevel SL>() -> SQDistanceComputer* {
153
+ if constexpr (SL != SIMDLevel::NONE) {
154
+ auto* dc = scalar_quantizer::sq_select_distance_computer<SL>(
155
+ metric, qtype, d, trained);
156
+ if (dc) {
157
+ return dc;
2283
158
  }
2284
159
  }
2285
- return nup;
2286
- }
2287
-
2288
- void scan_codes_range(
2289
- size_t list_size,
2290
- const uint8_t* codes,
2291
- const idx_t* ids,
2292
- float radius,
2293
- RangeQueryResult& res) const override {
2294
- for (size_t j = 0; j < list_size; j++, codes += code_size) {
2295
- if (use_sel && !sel->is_member(use_sel == 1 ? ids[j] : j)) {
2296
- continue;
2297
- }
2298
-
2299
- float dis = dc.query_to_code(codes);
2300
- if (dis < radius) {
2301
- int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
2302
- res.add(dis, id);
2303
- }
2304
- }
2305
- }
2306
- };
2307
-
2308
- template <class DCClass, int use_sel>
2309
- InvertedListScanner* sel3_InvertedListScanner(
2310
- const ScalarQuantizer* sq,
2311
- const Index* quantizer,
2312
- bool store_pairs,
2313
- const IDSelector* sel,
2314
- bool r) {
2315
- if (DCClass::Sim::metric_type == METRIC_L2) {
2316
- return new IVFSQScannerL2<DCClass, use_sel>(
2317
- sq->d,
2318
- sq->trained,
2319
- sq->code_size,
2320
- quantizer,
2321
- store_pairs,
2322
- sel,
2323
- r);
2324
- } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
2325
- return new IVFSQScannerIP<DCClass, use_sel>(
2326
- sq->d, sq->trained, sq->code_size, store_pairs, sel, r);
2327
- } else {
2328
- FAISS_THROW_MSG("unsupported metric type");
2329
- }
2330
- }
2331
-
2332
- template <class DCClass>
2333
- InvertedListScanner* sel2_InvertedListScanner(
2334
- const ScalarQuantizer* sq,
2335
- const Index* quantizer,
2336
- bool store_pairs,
2337
- const IDSelector* sel,
2338
- bool r) {
2339
- if (sel) {
2340
- if (store_pairs) {
2341
- return sel3_InvertedListScanner<DCClass, 2>(
2342
- sq, quantizer, store_pairs, sel, r);
2343
- } else {
2344
- return sel3_InvertedListScanner<DCClass, 1>(
2345
- sq, quantizer, store_pairs, sel, r);
2346
- }
2347
- } else {
2348
- return sel3_InvertedListScanner<DCClass, 0>(
2349
- sq, quantizer, store_pairs, sel, r);
2350
- }
2351
- }
2352
-
2353
- template <class Similarity, class Codec, QuantizerTemplateScaling SCALING>
2354
- InvertedListScanner* sel12_InvertedListScanner(
2355
- const ScalarQuantizer* sq,
2356
- const Index* quantizer,
2357
- bool store_pairs,
2358
- const IDSelector* sel,
2359
- bool r) {
2360
- constexpr int SIMDWIDTH = Similarity::simdwidth;
2361
- using QuantizerClass = QuantizerTemplate<Codec, SCALING, SIMDWIDTH>;
2362
- using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
2363
- return sel2_InvertedListScanner<DCClass>(
2364
- sq, quantizer, store_pairs, sel, r);
2365
- }
2366
-
2367
- template <class Similarity>
2368
- InvertedListScanner* sel1_InvertedListScanner(
2369
- const ScalarQuantizer* sq,
2370
- const Index* quantizer,
2371
- bool store_pairs,
2372
- const IDSelector* sel,
2373
- bool r) {
2374
- constexpr int SIMDWIDTH = Similarity::simdwidth;
2375
- switch (sq->qtype) {
2376
- case ScalarQuantizer::QT_8bit_uniform:
2377
- return sel12_InvertedListScanner<
2378
- Similarity,
2379
- Codec8bit,
2380
- QuantizerTemplateScaling::UNIFORM>(
2381
- sq, quantizer, store_pairs, sel, r);
2382
- case ScalarQuantizer::QT_4bit_uniform:
2383
- return sel12_InvertedListScanner<
2384
- Similarity,
2385
- Codec4bit,
2386
- QuantizerTemplateScaling::UNIFORM>(
2387
- sq, quantizer, store_pairs, sel, r);
2388
- case ScalarQuantizer::QT_8bit:
2389
- return sel12_InvertedListScanner<
2390
- Similarity,
2391
- Codec8bit,
2392
- QuantizerTemplateScaling::NON_UNIFORM>(
2393
- sq, quantizer, store_pairs, sel, r);
2394
- case ScalarQuantizer::QT_4bit:
2395
- return sel12_InvertedListScanner<
2396
- Similarity,
2397
- Codec4bit,
2398
- QuantizerTemplateScaling::NON_UNIFORM>(
2399
- sq, quantizer, store_pairs, sel, r);
2400
- case ScalarQuantizer::QT_6bit:
2401
- return sel12_InvertedListScanner<
2402
- Similarity,
2403
- Codec6bit,
2404
- QuantizerTemplateScaling::NON_UNIFORM>(
2405
- sq, quantizer, store_pairs, sel, r);
2406
- case ScalarQuantizer::QT_fp16:
2407
- return sel2_InvertedListScanner<DCTemplate<
2408
- QuantizerFP16<SIMDWIDTH>,
2409
- Similarity,
2410
- SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
2411
- case ScalarQuantizer::QT_bf16:
2412
- return sel2_InvertedListScanner<DCTemplate<
2413
- QuantizerBF16<SIMDWIDTH>,
2414
- Similarity,
2415
- SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
2416
- case ScalarQuantizer::QT_8bit_direct:
2417
- #if defined(__AVX512F__)
2418
- if (sq->d % 32 == 0) {
2419
- return sel2_InvertedListScanner<
2420
- DistanceComputerByte<Similarity, SIMDWIDTH>>(
2421
- sq, quantizer, store_pairs, sel, r);
2422
- } else
2423
- #elif defined(__AVX2__)
2424
- if (sq->d % 16 == 0) {
2425
- return sel2_InvertedListScanner<
2426
- DistanceComputerByte<Similarity, SIMDWIDTH>>(
2427
- sq, quantizer, store_pairs, sel, r);
2428
- } else
2429
- #endif
2430
- {
2431
- return sel2_InvertedListScanner<DCTemplate<
2432
- Quantizer8bitDirect<SIMDWIDTH>,
2433
- Similarity,
2434
- SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
2435
- }
2436
- case ScalarQuantizer::QT_8bit_direct_signed:
2437
- return sel2_InvertedListScanner<DCTemplate<
2438
- Quantizer8bitDirectSigned<SIMDWIDTH>,
2439
- Similarity,
2440
- SIMDWIDTH>>(sq, quantizer, store_pairs, sel, r);
2441
- }
2442
-
2443
- FAISS_THROW_MSG("unknown qtype");
2444
- return nullptr;
2445
- }
2446
-
2447
- template <int SIMDWIDTH>
2448
- InvertedListScanner* sel0_InvertedListScanner(
2449
- MetricType mt,
2450
- const ScalarQuantizer* sq,
2451
- const Index* quantizer,
2452
- bool store_pairs,
2453
- const IDSelector* sel,
2454
- bool by_residual) {
2455
- if (mt == METRIC_L2) {
2456
- return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
2457
- sq, quantizer, store_pairs, sel, by_residual);
2458
- } else if (mt == METRIC_INNER_PRODUCT) {
2459
- return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
2460
- sq, quantizer, store_pairs, sel, by_residual);
2461
- } else {
2462
- FAISS_THROW_MSG("unsupported metric type");
2463
- }
160
+ return scalar_quantizer::sq_select_distance_computer<SIMDLevel::NONE>(
161
+ metric, qtype, d, trained);
162
+ });
2464
163
  }
2465
164
 
2466
- } // anonymous namespace
2467
-
2468
165
  InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
2469
166
  MetricType mt,
2470
167
  const Index* quantizer,
2471
168
  bool store_pairs,
2472
169
  const IDSelector* sel,
2473
170
  bool by_residual) const {
2474
- #if defined(USE_AVX512_F16C)
2475
- if (d % 16 == 0) {
2476
- return sel0_InvertedListScanner<16>(
2477
- mt, this, quantizer, store_pairs, sel, by_residual);
2478
- } else
2479
- #elif defined(USE_F16C) || defined(USE_NEON)
2480
- if (d % 8 == 0) {
2481
- return sel0_InvertedListScanner<8>(
2482
- mt, this, quantizer, store_pairs, sel, by_residual);
2483
- } else
2484
- #endif
2485
- {
2486
- return sel0_InvertedListScanner<1>(
2487
- mt, this, quantizer, store_pairs, sel, by_residual);
2488
- }
171
+ return with_simd_level([&]<SIMDLevel SL>() -> InvertedListScanner* {
172
+ if constexpr (SL != SIMDLevel::NONE) {
173
+ auto* s = scalar_quantizer::sq_select_InvertedListScanner<SL>(
174
+ qtype,
175
+ mt,
176
+ d,
177
+ code_size,
178
+ trained,
179
+ quantizer,
180
+ store_pairs,
181
+ sel,
182
+ by_residual);
183
+ if (s) {
184
+ return s;
185
+ }
186
+ }
187
+ return scalar_quantizer::sq_select_InvertedListScanner<SIMDLevel::NONE>(
188
+ qtype,
189
+ mt,
190
+ d,
191
+ code_size,
192
+ trained,
193
+ quantizer,
194
+ store_pairs,
195
+ sel,
196
+ by_residual);
197
+ });
2489
198
  }
2490
199
 
2491
200
  } // namespace faiss