faiss 0.2.0 → 0.2.1

Files changed (202)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/lib/faiss/version.rb +1 -1
  4. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  5. data/vendor/faiss/faiss/AutoTune.h +55 -56
  6. data/vendor/faiss/faiss/Clustering.cpp +334 -195
  7. data/vendor/faiss/faiss/Clustering.h +88 -35
  8. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  9. data/vendor/faiss/faiss/IVFlib.h +48 -51
  10. data/vendor/faiss/faiss/Index.cpp +85 -103
  11. data/vendor/faiss/faiss/Index.h +54 -48
  12. data/vendor/faiss/faiss/Index2Layer.cpp +139 -164
  13. data/vendor/faiss/faiss/Index2Layer.h +22 -22
  14. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  15. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  16. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  17. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  18. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  19. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  20. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  21. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  22. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  23. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  24. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  25. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  26. data/vendor/faiss/faiss/IndexFlat.cpp +116 -147
  27. data/vendor/faiss/faiss/IndexFlat.h +35 -46
  28. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  29. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  30. data/vendor/faiss/faiss/IndexIVF.cpp +474 -454
  31. data/vendor/faiss/faiss/IndexIVF.h +146 -113
  32. data/vendor/faiss/faiss/IndexIVFFlat.cpp +248 -250
  33. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  34. data/vendor/faiss/faiss/IndexIVFPQ.cpp +457 -516
  35. data/vendor/faiss/faiss/IndexIVFPQ.h +74 -66
  36. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  38. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  39. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  40. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +125 -133
  41. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +19 -21
  42. data/vendor/faiss/faiss/IndexLSH.cpp +75 -96
  43. data/vendor/faiss/faiss/IndexLSH.h +21 -26
  44. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  45. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  46. data/vendor/faiss/faiss/IndexNNDescent.cpp +231 -0
  47. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  48. data/vendor/faiss/faiss/IndexNSG.cpp +303 -0
  49. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  50. data/vendor/faiss/faiss/IndexPQ.cpp +405 -464
  51. data/vendor/faiss/faiss/IndexPQ.h +64 -67
  52. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  53. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  54. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  55. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  56. data/vendor/faiss/faiss/IndexRefine.cpp +115 -131
  57. data/vendor/faiss/faiss/IndexRefine.h +22 -23
  58. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  59. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  60. data/vendor/faiss/faiss/IndexResidual.cpp +291 -0
  61. data/vendor/faiss/faiss/IndexResidual.h +152 -0
  62. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +120 -155
  63. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -45
  64. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  65. data/vendor/faiss/faiss/IndexShards.h +85 -73
  66. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  67. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  68. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  69. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  70. data/vendor/faiss/faiss/MetricType.h +7 -7
  71. data/vendor/faiss/faiss/VectorTransform.cpp +652 -474
  72. data/vendor/faiss/faiss/VectorTransform.h +61 -89
  73. data/vendor/faiss/faiss/clone_index.cpp +77 -73
  74. data/vendor/faiss/faiss/clone_index.h +4 -9
  75. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  76. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  77. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +197 -170
  78. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  79. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  82. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  84. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  85. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  89. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  90. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  91. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  92. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  93. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  94. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  95. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  96. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  97. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  98. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  99. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  100. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  101. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  102. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  103. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  104. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  105. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  106. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  107. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  108. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  110. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  111. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  112. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  113. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  114. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  115. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  116. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  117. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  118. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  119. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +270 -0
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +115 -0
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  125. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  126. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  127. data/vendor/faiss/faiss/impl/HNSW.cpp +595 -611
  128. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +672 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +172 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +682 -0
  134. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  135. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  136. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  137. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  139. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +448 -0
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +130 -0
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +648 -701
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  145. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  146. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  147. data/vendor/faiss/faiss/impl/index_read.cpp +547 -479
  148. data/vendor/faiss/faiss/impl/index_write.cpp +497 -407
  149. data/vendor/faiss/faiss/impl/io.cpp +75 -94
  150. data/vendor/faiss/faiss/impl/io.h +31 -41
  151. data/vendor/faiss/faiss/impl/io_macros.h +40 -29
  152. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  153. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  154. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  159. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  160. data/vendor/faiss/faiss/index_factory.cpp +269 -218
  161. data/vendor/faiss/faiss/index_factory.h +6 -7
  162. data/vendor/faiss/faiss/index_io.h +23 -26
  163. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  164. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  165. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  166. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  167. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  168. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  169. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  170. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  172. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  173. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  174. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  175. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  176. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  177. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  178. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  179. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  180. data/vendor/faiss/faiss/utils/distances.cpp +301 -310
  181. data/vendor/faiss/faiss/utils/distances.h +133 -118
  182. data/vendor/faiss/faiss/utils/distances_simd.cpp +456 -516
  183. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  184. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  185. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  186. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  187. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  188. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  189. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  190. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  191. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  192. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  193. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  194. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  195. data/vendor/faiss/faiss/utils/random.h +13 -16
  196. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  197. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  198. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  199. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  200. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  201. data/vendor/faiss/faiss/utils/utils.h +53 -48
  202. metadata +20 -2
@@ -9,18 +9,19 @@
9
9
 
10
10
  #include <faiss/impl/ScalarQuantizer.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <algorithm>
13
+ #include <cstdio>
14
14
 
15
+ #include <faiss/impl/platform_macros.h>
15
16
  #include <omp.h>
16
17
 
17
18
  #ifdef __SSE__
18
19
  #include <immintrin.h>
19
20
  #endif
20
21
 
21
- #include <faiss/utils/utils.h>
22
- #include <faiss/impl/FaissAssert.h>
23
22
  #include <faiss/impl/AuxIndexStructures.h>
23
+ #include <faiss/impl/FaissAssert.h>
24
+ #include <faiss/utils/utils.h>
24
25
 
25
26
  namespace faiss {
26
27
 
@@ -43,11 +44,11 @@ namespace faiss {
43
44
  #ifdef __F16C__
44
45
  #define USE_F16C
45
46
  #else
46
- #warning "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
47
+ #warning \
48
+ "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
47
49
  #endif
48
50
  #endif
49
51
 
50
-
51
52
  namespace {
52
53
 
53
54
  typedef Index::idx_t idx_t;
@@ -55,7 +56,6 @@ typedef ScalarQuantizer::QuantizerType QuantizerType;
55
56
  typedef ScalarQuantizer::RangeStat RangeStat;
56
57
  using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
57
58
 
58
-
59
59
  /*******************************************************************
60
60
  * Codec: converts between values in [0, 1] and an index in a code
61
61
  * array. The "i" parameter is the vector component index (not byte
@@ -63,108 +63,103 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
63
63
  */
64
64
 
65
65
  struct Codec8bit {
66
-
67
- static void encode_component (float x, uint8_t *code, int i) {
66
+ static void encode_component(float x, uint8_t* code, int i) {
68
67
  code[i] = (int)(255 * x);
69
68
  }
70
69
 
71
- static float decode_component (const uint8_t *code, int i) {
70
+ static float decode_component(const uint8_t* code, int i) {
72
71
  return (code[i] + 0.5f) / 255.0f;
73
72
  }
74
73
 
75
74
  #ifdef __AVX2__
76
- static __m256 decode_8_components (const uint8_t *code, int i) {
75
+ static __m256 decode_8_components(const uint8_t* code, int i) {
77
76
  uint64_t c8 = *(uint64_t*)(code + i);
78
- __m128i c4lo = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8));
79
- __m128i c4hi = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8 >> 32));
77
+ __m128i c4lo = _mm_cvtepu8_epi32(_mm_set1_epi32(c8));
78
+ __m128i c4hi = _mm_cvtepu8_epi32(_mm_set1_epi32(c8 >> 32));
80
79
  // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
81
- __m256i i8 = _mm256_castsi128_si256 (c4lo);
82
- i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
83
- __m256 f8 = _mm256_cvtepi32_ps (i8);
84
- __m256 half = _mm256_set1_ps (0.5f);
85
- f8 += half;
86
- __m256 one_255 = _mm256_set1_ps (1.f / 255.f);
87
- return f8 * one_255;
80
+ __m256i i8 = _mm256_castsi128_si256(c4lo);
81
+ i8 = _mm256_insertf128_si256(i8, c4hi, 1);
82
+ __m256 f8 = _mm256_cvtepi32_ps(i8);
83
+ __m256 half = _mm256_set1_ps(0.5f);
84
+ f8 = _mm256_add_ps(f8, half);
85
+ __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
86
+ return _mm256_mul_ps(f8, one_255);
88
87
  }
89
88
  #endif
90
89
  };
91
90
 
92
-
93
91
  struct Codec4bit {
94
-
95
- static void encode_component (float x, uint8_t *code, int i) {
96
- code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
92
+ static void encode_component(float x, uint8_t* code, int i) {
93
+ code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
97
94
  }
98
95
 
99
- static float decode_component (const uint8_t *code, int i) {
96
+ static float decode_component(const uint8_t* code, int i) {
100
97
  return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
101
98
  }
102
99
 
103
-
104
100
  #ifdef __AVX2__
105
- static __m256 decode_8_components (const uint8_t *code, int i) {
101
+ static __m256 decode_8_components(const uint8_t* code, int i) {
106
102
  uint32_t c4 = *(uint32_t*)(code + (i >> 1));
107
103
  uint32_t mask = 0x0f0f0f0f;
108
104
  uint32_t c4ev = c4 & mask;
109
105
  uint32_t c4od = (c4 >> 4) & mask;
110
106
 
111
107
  // the 8 lower bytes of c8 contain the values
112
- __m128i c8 = _mm_unpacklo_epi8 (_mm_set1_epi32(c4ev),
113
- _mm_set1_epi32(c4od));
114
- __m128i c4lo = _mm_cvtepu8_epi32 (c8);
115
- __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128(c8, 4));
116
- __m256i i8 = _mm256_castsi128_si256 (c4lo);
117
- i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
118
- __m256 f8 = _mm256_cvtepi32_ps (i8);
119
- __m256 half = _mm256_set1_ps (0.5f);
120
- f8 += half;
121
- __m256 one_255 = _mm256_set1_ps (1.f / 15.f);
122
- return f8 * one_255;
108
+ __m128i c8 =
109
+ _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od));
110
+ __m128i c4lo = _mm_cvtepu8_epi32(c8);
111
+ __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4));
112
+ __m256i i8 = _mm256_castsi128_si256(c4lo);
113
+ i8 = _mm256_insertf128_si256(i8, c4hi, 1);
114
+ __m256 f8 = _mm256_cvtepi32_ps(i8);
115
+ __m256 half = _mm256_set1_ps(0.5f);
116
+ f8 = _mm256_add_ps(f8, half);
117
+ __m256 one_255 = _mm256_set1_ps(1.f / 15.f);
118
+ return _mm256_mul_ps(f8, one_255);
123
119
  }
124
120
  #endif
125
121
  };
126
122
 
127
123
  struct Codec6bit {
128
-
129
- static void encode_component (float x, uint8_t *code, int i) {
124
+ static void encode_component(float x, uint8_t* code, int i) {
130
125
  int bits = (int)(x * 63.0);
131
126
  code += (i >> 2) * 3;
132
- switch(i & 3) {
133
- case 0:
134
- code[0] |= bits;
135
- break;
136
- case 1:
137
- code[0] |= bits << 6;
138
- code[1] |= bits >> 2;
139
- break;
140
- case 2:
141
- code[1] |= bits << 4;
142
- code[2] |= bits >> 4;
143
- break;
144
- case 3:
145
- code[2] |= bits << 2;
146
- break;
127
+ switch (i & 3) {
128
+ case 0:
129
+ code[0] |= bits;
130
+ break;
131
+ case 1:
132
+ code[0] |= bits << 6;
133
+ code[1] |= bits >> 2;
134
+ break;
135
+ case 2:
136
+ code[1] |= bits << 4;
137
+ code[2] |= bits >> 4;
138
+ break;
139
+ case 3:
140
+ code[2] |= bits << 2;
141
+ break;
147
142
  }
148
143
  }
149
144
 
150
- static float decode_component (const uint8_t *code, int i) {
145
+ static float decode_component(const uint8_t* code, int i) {
151
146
  uint8_t bits;
152
147
  code += (i >> 2) * 3;
153
- switch(i & 3) {
154
- case 0:
155
- bits = code[0] & 0x3f;
156
- break;
157
- case 1:
158
- bits = code[0] >> 6;
159
- bits |= (code[1] & 0xf) << 2;
160
- break;
161
- case 2:
162
- bits = code[1] >> 4;
163
- bits |= (code[2] & 3) << 4;
164
- break;
165
- case 3:
166
- bits = code[2] >> 2;
167
- break;
148
+ switch (i & 3) {
149
+ case 0:
150
+ bits = code[0] & 0x3f;
151
+ break;
152
+ case 1:
153
+ bits = code[0] >> 6;
154
+ bits |= (code[1] & 0xf) << 2;
155
+ break;
156
+ case 2:
157
+ bits = code[1] >> 4;
158
+ bits |= (code[2] & 3) << 4;
159
+ break;
160
+ case 3:
161
+ bits = code[2] >> 2;
162
+ break;
168
163
  }
169
164
  return (bits + 0.5f) / 63.0f;
170
165
  }
@@ -173,12 +168,14 @@ struct Codec6bit {
173
168
 
174
169
  /* Load 6 bytes that represent 8 6-bit values, return them as a
175
170
  * 8*32 bit vector register */
176
- static __m256i load6 (const uint16_t *code16) {
177
- const __m128i perm = _mm_set_epi8(-1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
171
+ static __m256i load6(const uint16_t* code16) {
172
+ const __m128i perm = _mm_set_epi8(
173
+ -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
178
174
  const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
179
175
 
180
176
  // load 6 bytes
181
- __m128i c1 = _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
177
+ __m128i c1 =
178
+ _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
182
179
 
183
180
  // put in 8 * 32 bits
184
181
  __m128i c2 = _mm_shuffle_epi8(c1, perm);
@@ -190,37 +187,33 @@ struct Codec6bit {
190
187
  return c5;
191
188
  }
192
189
 
193
- static __m256 decode_8_components (const uint8_t *code, int i) {
194
- __m256i i8 = load6 ((const uint16_t *)(code + (i >> 2) * 3));
195
- __m256 f8 = _mm256_cvtepi32_ps (i8);
190
+ static __m256 decode_8_components(const uint8_t* code, int i) {
191
+ __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
192
+ __m256 f8 = _mm256_cvtepi32_ps(i8);
196
193
  // this could also be done with bit manipulations but it is
197
194
  // not obviously faster
198
- __m256 half = _mm256_set1_ps (0.5f);
199
- f8 += half;
200
- __m256 one_63 = _mm256_set1_ps (1.f / 63.f);
201
- return f8 * one_63;
195
+ __m256 half = _mm256_set1_ps(0.5f);
196
+ f8 = _mm256_add_ps(f8, half);
197
+ __m256 one_63 = _mm256_set1_ps(1.f / 63.f);
198
+ return _mm256_mul_ps(f8, one_63);
202
199
  }
203
200
 
204
201
  #endif
205
202
  };
206
203
 
207
-
208
-
209
204
  #ifdef USE_F16C
210
205
 
211
-
212
- uint16_t encode_fp16 (float x) {
213
- __m128 xf = _mm_set1_ps (x);
214
- __m128i xi = _mm_cvtps_ph (
215
- xf, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
216
- return _mm_cvtsi128_si32 (xi) & 0xffff;
206
+ uint16_t encode_fp16(float x) {
207
+ __m128 xf = _mm_set1_ps(x);
208
+ __m128i xi =
209
+ _mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
210
+ return _mm_cvtsi128_si32(xi) & 0xffff;
217
211
  }
218
212
 
219
-
220
- float decode_fp16 (uint16_t x) {
221
- __m128i xi = _mm_set1_epi16 (x);
222
- __m128 xf = _mm_cvtph_ps (xi);
223
- return _mm_cvtss_f32 (xf);
213
+ float decode_fp16(uint16_t x) {
214
+ __m128i xi = _mm_set1_epi16(x);
215
+ __m128 xf = _mm_cvtph_ps(xi);
216
+ return _mm_cvtss_f32(xf);
224
217
  }
225
218
 
226
219
  #else
@@ -228,19 +221,17 @@ float decode_fp16 (uint16_t x) {
228
221
  // non-intrinsic FP16 <-> FP32 code adapted from
229
222
  // https://github.com/ispc/ispc/blob/master/stdlib.ispc
230
223
 
231
- float floatbits (uint32_t x) {
232
- void *xptr = &x;
224
+ float floatbits(uint32_t x) {
225
+ void* xptr = &x;
233
226
  return *(float*)xptr;
234
227
  }
235
228
 
236
- uint32_t intbits (float f) {
237
- void *fptr = &f;
229
+ uint32_t intbits(float f) {
230
+ void* fptr = &f;
238
231
  return *(uint32_t*)fptr;
239
232
  }
240
233
 
241
-
242
- uint16_t encode_fp16 (float f) {
243
-
234
+ uint16_t encode_fp16(float f) {
244
235
  // via Fabian "ryg" Giesen.
245
236
  // https://gist.github.com/2156668
246
237
  uint32_t sign_mask = 0x80000000u;
@@ -297,20 +288,19 @@ uint16_t encode_fp16 (float f) {
297
288
  return (o | (sign >> 16));
298
289
  }
299
290
 
300
- float decode_fp16 (uint16_t h) {
301
-
291
+ float decode_fp16(uint16_t h) {
302
292
  // https://gist.github.com/2144712
303
293
  // Fabian "ryg" Giesen.
304
294
 
305
295
  const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
306
296
 
307
- int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
308
- int32_t exp = shifted_exp & o; // just the exponent
309
- o += (int32_t)(127 - 15) << 23; // exponent adjust
297
+ int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
298
+ int32_t exp = shifted_exp & o; // just the exponent
299
+ o += (int32_t)(127 - 15) << 23; // exponent adjust
310
300
 
311
301
  int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
312
- int32_t zerodenorm_val = intbits(
313
- floatbits(o + (1u<<23)) - floatbits(113u << 23));
302
+ int32_t zerodenorm_val =
303
+ intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
314
304
  int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
315
305
 
316
306
  int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
@@ -319,30 +309,21 @@ float decode_fp16 (uint16_t h) {
319
309
 
320
310
  #endif
321
311
 
322
-
323
-
324
312
  /*******************************************************************
325
313
  * Quantizer: normalizes scalar vector components, then passes them
326
314
  * through a codec
327
315
  *******************************************************************/
328
316
 
329
-
330
-
331
-
332
-
333
- template<class Codec, bool uniform, int SIMD>
317
+ template <class Codec, bool uniform, int SIMD>
334
318
  struct QuantizerTemplate {};
335
319
 
336
-
337
- template<class Codec>
338
- struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
320
+ template <class Codec>
321
+ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
339
322
  const size_t d;
340
323
  const float vmin, vdiff;
341
324
 
342
- QuantizerTemplate(size_t d, const std::vector<float> &trained):
343
- d(d), vmin(trained[0]), vdiff(trained[1])
344
- {
345
- }
325
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
326
+ : d(d), vmin(trained[0]), vdiff(trained[1]) {}
346
327
 
347
328
  void encode_vector(const float* x, uint8_t* code) const final {
348
329
  for (size_t i = 0; i < d; i++) {
@@ -367,43 +348,36 @@ struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
367
348
  }
368
349
  }
369
350
 
370
- float reconstruct_component (const uint8_t * code, int i) const
371
- {
372
- float xi = Codec::decode_component (code, i);
351
+ float reconstruct_component(const uint8_t* code, int i) const {
352
+ float xi = Codec::decode_component(code, i);
373
353
  return vmin + xi * vdiff;
374
354
  }
375
-
376
355
  };
377
356
 
378
-
379
-
380
357
  #ifdef __AVX2__
381
358
 
382
- template<class Codec>
383
- struct QuantizerTemplate<Codec, true, 8>: QuantizerTemplate<Codec, true, 1> {
384
-
385
- QuantizerTemplate (size_t d, const std::vector<float> &trained):
386
- QuantizerTemplate<Codec, true, 1> (d, trained) {}
359
+ template <class Codec>
360
+ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
361
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
362
+ : QuantizerTemplate<Codec, true, 1>(d, trained) {}
387
363
 
388
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
389
- {
390
- __m256 xi = Codec::decode_8_components (code, i);
391
- return _mm256_set1_ps(this->vmin) + xi * _mm256_set1_ps (this->vdiff);
364
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
365
+ __m256 xi = Codec::decode_8_components(code, i);
366
+ return _mm256_add_ps(
367
+ _mm256_set1_ps(this->vmin),
368
+ _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
392
369
  }
393
-
394
370
  };
395
371
 
396
372
  #endif
397
373
 
398
-
399
-
400
- template<class Codec>
401
- struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
374
+ template <class Codec>
375
+ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
402
376
  const size_t d;
403
377
  const float *vmin, *vdiff;
404
378
 
405
- QuantizerTemplate (size_t d, const std::vector<float> &trained):
406
- d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
379
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
380
+ : d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
407
381
 
408
382
  void encode_vector(const float* x, uint8_t* code) const final {
409
383
  for (size_t i = 0; i < d; i++) {
@@ -428,30 +402,25 @@ struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
428
402
  }
429
403
  }
430
404
 
431
- float reconstruct_component (const uint8_t * code, int i) const
432
- {
433
- float xi = Codec::decode_component (code, i);
405
+ float reconstruct_component(const uint8_t* code, int i) const {
406
+ float xi = Codec::decode_component(code, i);
434
407
  return vmin[i] + xi * vdiff[i];
435
408
  }
436
-
437
409
  };
438
410
 
439
-
440
411
  #ifdef __AVX2__
441
412
 
442
- template<class Codec>
443
- struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
413
+ template <class Codec>
414
+ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
415
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
416
+ : QuantizerTemplate<Codec, false, 1>(d, trained) {}
444
417
 
445
- QuantizerTemplate (size_t d, const std::vector<float> &trained):
446
- QuantizerTemplate<Codec, false, 1> (d, trained) {}
447
-
448
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
449
- {
450
- __m256 xi = Codec::decode_8_components (code, i);
451
- return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i);
418
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
419
+ __m256 xi = Codec::decode_8_components(code, i);
420
+ return _mm256_add_ps(
421
+ _mm256_loadu_ps(this->vmin + i),
422
+ _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
452
423
  }
453
-
454
-
455
424
  };
456
425
 
457
426
  #endif
@@ -460,15 +429,14 @@ struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
460
429
  * FP16 quantizer
461
430
  *******************************************************************/
462
431
 
463
- template<int SIMDWIDTH>
432
+ template <int SIMDWIDTH>
464
433
  struct QuantizerFP16 {};
465
434
 
466
- template<>
467
- struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
435
+ template <>
436
+ struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
468
437
  const size_t d;
469
438
 
470
- QuantizerFP16(size_t d, const std::vector<float> & /* unused */):
471
- d(d) {}
439
+ QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
472
440
 
473
441
  void encode_vector(const float* x, uint8_t* code) const final {
474
442
  for (size_t i = 0; i < d; i++) {
@@ -482,27 +450,22 @@ struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
482
450
  }
483
451
  }
484
452
 
485
- float reconstruct_component (const uint8_t * code, int i) const
486
- {
453
+ float reconstruct_component(const uint8_t* code, int i) const {
487
454
  return decode_fp16(((uint16_t*)code)[i]);
488
455
  }
489
-
490
456
  };
491
457
 
492
458
  #ifdef USE_F16C
493
459
 
494
- template<>
495
- struct QuantizerFP16<8>: QuantizerFP16<1> {
496
-
497
- QuantizerFP16 (size_t d, const std::vector<float> &trained):
498
- QuantizerFP16<1> (d, trained) {}
460
+ template <>
461
+ struct QuantizerFP16<8> : QuantizerFP16<1> {
462
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
463
+ : QuantizerFP16<1>(d, trained) {}
499
464
 
500
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
501
- {
502
- __m128i codei = _mm_loadu_si128 ((const __m128i*)(code + 2 * i));
503
- return _mm256_cvtph_ps (codei);
465
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
466
+ __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
467
+ return _mm256_cvtph_ps(codei);
504
468
  }
505
-
506
469
  };
507
470
 
508
471
  #endif
@@ -511,16 +474,15 @@ struct QuantizerFP16<8>: QuantizerFP16<1> {
511
474
  * 8bit_direct quantizer
512
475
  *******************************************************************/
513
476
 
514
- template<int SIMDWIDTH>
477
+ template <int SIMDWIDTH>
515
478
  struct Quantizer8bitDirect {};
516
479
 
517
- template<>
518
- struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
480
+ template <>
481
+ struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
519
482
  const size_t d;
520
483
 
521
- Quantizer8bitDirect(size_t d, const std::vector<float> & /* unused */):
522
- d(d) {}
523
-
484
+ Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
485
+ : d(d) {}
524
486
 
525
487
  void encode_vector(const float* x, uint8_t* code) const final {
526
488
  for (size_t i = 0; i < d; i++) {
@@ -534,82 +496,83 @@ struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
534
496
  }
535
497
  }
536
498
 
537
- float reconstruct_component (const uint8_t * code, int i) const
538
- {
499
+ float reconstruct_component(const uint8_t* code, int i) const {
539
500
  return code[i];
540
501
  }
541
-
542
502
  };
543
503
 
544
504
  #ifdef __AVX2__
545
505
 
546
- template<>
547
- struct Quantizer8bitDirect<8>: Quantizer8bitDirect<1> {
548
-
549
- Quantizer8bitDirect (size_t d, const std::vector<float> &trained):
550
- Quantizer8bitDirect<1> (d, trained) {}
506
+ template <>
507
+ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
508
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
509
+ : Quantizer8bitDirect<1>(d, trained) {}
551
510
 
552
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
553
- {
511
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
554
512
  __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
555
- __m256i y8 = _mm256_cvtepu8_epi32 (x8); // 8 * int32
556
- return _mm256_cvtepi32_ps (y8); // 8 * float32
513
+ __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
514
+ return _mm256_cvtepi32_ps(y8); // 8 * float32
557
515
  }
558
-
559
516
  };
560
517
 
561
518
  #endif
562
519
 
563
-
564
- template<int SIMDWIDTH>
565
- ScalarQuantizer::Quantizer *select_quantizer_1 (
566
- QuantizerType qtype,
567
- size_t d, const std::vector<float> & trained)
568
- {
569
- switch(qtype) {
570
- case ScalarQuantizer::QT_8bit:
571
- return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(d, trained);
572
- case ScalarQuantizer::QT_6bit:
573
- return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(d, trained);
574
- case ScalarQuantizer::QT_4bit:
575
- return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(d, trained);
576
- case ScalarQuantizer::QT_8bit_uniform:
577
- return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(d, trained);
578
- case ScalarQuantizer::QT_4bit_uniform:
579
- return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(d, trained);
580
- case ScalarQuantizer::QT_fp16:
581
- return new QuantizerFP16<SIMDWIDTH> (d, trained);
582
- case ScalarQuantizer::QT_8bit_direct:
583
- return new Quantizer8bitDirect<SIMDWIDTH> (d, trained);
584
- }
585
- FAISS_THROW_MSG ("unknown qtype");
520
+ template <int SIMDWIDTH>
521
+ ScalarQuantizer::Quantizer* select_quantizer_1(
522
+ QuantizerType qtype,
523
+ size_t d,
524
+ const std::vector<float>& trained) {
525
+ switch (qtype) {
526
+ case ScalarQuantizer::QT_8bit:
527
+ return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
528
+ d, trained);
529
+ case ScalarQuantizer::QT_6bit:
530
+ return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
531
+ d, trained);
532
+ case ScalarQuantizer::QT_4bit:
533
+ return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
534
+ d, trained);
535
+ case ScalarQuantizer::QT_8bit_uniform:
536
+ return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
537
+ d, trained);
538
+ case ScalarQuantizer::QT_4bit_uniform:
539
+ return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
540
+ d, trained);
541
+ case ScalarQuantizer::QT_fp16:
542
+ return new QuantizerFP16<SIMDWIDTH>(d, trained);
543
+ case ScalarQuantizer::QT_8bit_direct:
544
+ return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
545
+ }
546
+ FAISS_THROW_MSG("unknown qtype");
586
547
  }
587
548
 
588
-
589
-
590
-
591
549
  /*******************************************************************
592
550
  * Quantizer range training
593
551
  */
594
552
 
595
- static float sqr (float x) {
553
+ static float sqr(float x) {
596
554
  return x * x;
597
555
  }
598
556
 
599
-
600
- void train_Uniform(RangeStat rs, float rs_arg,
601
- idx_t n, int k, const float *x,
602
- std::vector<float> & trained)
603
- {
604
- trained.resize (2);
605
- float & vmin = trained[0];
606
- float & vmax = trained[1];
557
+ void train_Uniform(
558
+ RangeStat rs,
559
+ float rs_arg,
560
+ idx_t n,
561
+ int k,
562
+ const float* x,
563
+ std::vector<float>& trained) {
564
+ trained.resize(2);
565
+ float& vmin = trained[0];
566
+ float& vmax = trained[1];
607
567
 
608
568
  if (rs == ScalarQuantizer::RS_minmax) {
609
- vmin = HUGE_VAL; vmax = -HUGE_VAL;
569
+ vmin = HUGE_VAL;
570
+ vmax = -HUGE_VAL;
610
571
  for (size_t i = 0; i < n; i++) {
611
- if (x[i] < vmin) vmin = x[i];
612
- if (x[i] > vmax) vmax = x[i];
572
+ if (x[i] < vmin)
573
+ vmin = x[i];
574
+ if (x[i] > vmax)
575
+ vmax = x[i];
613
576
  }
614
577
  float vexp = (vmax - vmin) * rs_arg;
615
578
  vmin -= vexp;
@@ -624,16 +587,18 @@ void train_Uniform(RangeStat rs, float rs_arg,
624
587
  float var = sum2 / n - mean * mean;
625
588
  float std = var <= 0 ? 1.0 : sqrt(var);
626
589
 
627
- vmin = mean - std * rs_arg ;
628
- vmax = mean + std * rs_arg ;
590
+ vmin = mean - std * rs_arg;
591
+ vmax = mean + std * rs_arg;
629
592
  } else if (rs == ScalarQuantizer::RS_quantiles) {
630
593
  std::vector<float> x_copy(n);
631
594
  memcpy(x_copy.data(), x, n * sizeof(*x));
632
595
  // TODO just do a qucikselect
633
596
  std::sort(x_copy.begin(), x_copy.end());
634
597
  int o = int(rs_arg * n);
635
- if (o < 0) o = 0;
636
- if (o > n - o) o = n / 2;
598
+ if (o < 0)
599
+ o = 0;
600
+ if (o > n - o)
601
+ o = n / 2;
637
602
  vmin = x_copy[o];
638
603
  vmax = x_copy[n - 1 - o];
639
604
 
@@ -643,8 +608,10 @@ void train_Uniform(RangeStat rs, float rs_arg,
643
608
  {
644
609
  vmin = HUGE_VAL, vmax = -HUGE_VAL;
645
610
  for (size_t i = 0; i < n; i++) {
646
- if (x[i] < vmin) vmin = x[i];
647
- if (x[i] > vmax) vmax = x[i];
611
+ if (x[i] < vmin)
612
+ vmin = x[i];
613
+ if (x[i] > vmax)
614
+ vmax = x[i];
648
615
  sx += x[i];
649
616
  }
650
617
  b = vmin;
@@ -659,62 +626,71 @@ void train_Uniform(RangeStat rs, float rs_arg,
659
626
 
660
627
  for (idx_t i = 0; i < n; i++) {
661
628
  float xi = x[i];
662
- float ni = floor ((xi - b) / a + 0.5);
663
- if (ni < 0) ni = 0;
664
- if (ni >= k) ni = k - 1;
665
- err1 += sqr (xi - (ni * a + b));
666
- sn += ni;
629
+ float ni = floor((xi - b) / a + 0.5);
630
+ if (ni < 0)
631
+ ni = 0;
632
+ if (ni >= k)
633
+ ni = k - 1;
634
+ err1 += sqr(xi - (ni * a + b));
635
+ sn += ni;
667
636
  sn2 += ni * ni;
668
637
  sxn += ni * xi;
669
638
  }
670
639
 
671
640
  if (err1 == last_err) {
672
- iter_last_err ++;
673
- if (iter_last_err == 16) break;
641
+ iter_last_err++;
642
+ if (iter_last_err == 16)
643
+ break;
674
644
  } else {
675
645
  last_err = err1;
676
646
  iter_last_err = 0;
677
647
  }
678
648
 
679
- float det = sqr (sn) - sn2 * n;
649
+ float det = sqr(sn) - sn2 * n;
680
650
 
681
651
  b = (sn * sxn - sn2 * sx) / det;
682
652
  a = (sn * sx - n * sxn) / det;
683
653
  if (verbose) {
684
- printf ("it %d, err1=%g \r", it, err1);
654
+ printf("it %d, err1=%g \r", it, err1);
685
655
  fflush(stdout);
686
656
  }
687
657
  }
688
- if (verbose) printf("\n");
658
+ if (verbose)
659
+ printf("\n");
689
660
 
690
661
  vmin = b;
691
662
  vmax = b + a * (k - 1);
692
663
 
693
664
  } else {
694
- FAISS_THROW_MSG ("Invalid qtype");
665
+ FAISS_THROW_MSG("Invalid qtype");
695
666
  }
696
667
  vmax -= vmin;
697
668
  }
698
669
 
699
- void train_NonUniform(RangeStat rs, float rs_arg,
700
- idx_t n, int d, int k, const float *x,
701
- std::vector<float> & trained)
702
- {
703
-
704
- trained.resize (2 * d);
705
- float * vmin = trained.data();
706
- float * vmax = trained.data() + d;
670
+ void train_NonUniform(
671
+ RangeStat rs,
672
+ float rs_arg,
673
+ idx_t n,
674
+ int d,
675
+ int k,
676
+ const float* x,
677
+ std::vector<float>& trained) {
678
+ trained.resize(2 * d);
679
+ float* vmin = trained.data();
680
+ float* vmax = trained.data() + d;
707
681
  if (rs == ScalarQuantizer::RS_minmax) {
708
- memcpy (vmin, x, sizeof(*x) * d);
709
- memcpy (vmax, x, sizeof(*x) * d);
682
+ memcpy(vmin, x, sizeof(*x) * d);
683
+ memcpy(vmax, x, sizeof(*x) * d);
710
684
  for (size_t i = 1; i < n; i++) {
711
- const float *xi = x + i * d;
685
+ const float* xi = x + i * d;
712
686
  for (size_t j = 0; j < d; j++) {
713
- if (xi[j] < vmin[j]) vmin[j] = xi[j];
714
- if (xi[j] > vmax[j]) vmax[j] = xi[j];
687
+ if (xi[j] < vmin[j])
688
+ vmin[j] = xi[j];
689
+ if (xi[j] > vmax[j])
690
+ vmax[j] = xi[j];
715
691
  }
716
692
  }
717
- float *vdiff = vmax;
693
+ float* vdiff = vmax;
718
694
  for (size_t j = 0; j < d; j++) {
719
695
  float vexp = (vmax[j] - vmin[j]) * rs_arg;
720
696
  vmin[j] -= vexp;
@@ -725,7 +701,7 @@ void train_NonUniform(RangeStat rs, float rs_arg,
725
701
  // transpose
726
702
  std::vector<float> xt(n * d);
727
703
  for (size_t i = 1; i < n; i++) {
728
- const float *xi = x + i * d;
704
+ const float* xi = x + i * d;
729
705
  for (size_t j = 0; j < d; j++) {
730
706
  xt[j * n + i] = xi[j];
731
707
  }
@@ -733,108 +709,98 @@ void train_NonUniform(RangeStat rs, float rs_arg,
733
709
  std::vector<float> trained_d(2);
734
710
  #pragma omp parallel for
735
711
  for (int j = 0; j < d; j++) {
736
- train_Uniform(rs, rs_arg,
737
- n, k, xt.data() + j * n,
738
- trained_d);
712
+ train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d);
739
713
  vmin[j] = trained_d[0];
740
714
  vmax[j] = trained_d[1];
741
715
  }
742
716
  }
743
717
  }
744
718
 
745
-
746
-
747
719
  /*******************************************************************
748
720
  * Similarity: gets vector components and computes a similarity wrt. a
749
721
  * query vector stored in the object. The data fields just encapsulate
750
722
  * an accumulator.
751
723
  */
752
724
 
753
- template<int SIMDWIDTH>
725
+ template <int SIMDWIDTH>
754
726
  struct SimilarityL2 {};
755
727
 
756
-
757
- template<>
728
+ template <>
758
729
  struct SimilarityL2<1> {
759
730
  static constexpr int simdwidth = 1;
760
731
  static constexpr MetricType metric_type = METRIC_L2;
761
732
 
762
733
  const float *y, *yi;
763
734
 
764
- explicit SimilarityL2 (const float * y): y(y) {}
735
+ explicit SimilarityL2(const float* y) : y(y) {}
765
736
 
766
737
  /******* scalar accumulator *******/
767
738
 
768
739
  float accu;
769
740
 
770
- void begin () {
741
+ void begin() {
771
742
  accu = 0;
772
743
  yi = y;
773
744
  }
774
745
 
775
- void add_component (float x) {
746
+ void add_component(float x) {
776
747
  float tmp = *yi++ - x;
777
748
  accu += tmp * tmp;
778
749
  }
779
750
 
780
- void add_component_2 (float x1, float x2) {
751
+ void add_component_2(float x1, float x2) {
781
752
  float tmp = x1 - x2;
782
753
  accu += tmp * tmp;
783
754
  }
784
755
 
785
- float result () {
756
+ float result() {
786
757
  return accu;
787
758
  }
788
759
  };
789
760
 
790
-
791
761
  #ifdef __AVX2__
792
- template<>
762
+ template <>
793
763
  struct SimilarityL2<8> {
794
764
  static constexpr int simdwidth = 8;
795
765
  static constexpr MetricType metric_type = METRIC_L2;
796
766
 
797
767
  const float *y, *yi;
798
768
 
799
- explicit SimilarityL2 (const float * y): y(y) {}
769
+ explicit SimilarityL2(const float* y) : y(y) {}
800
770
  __m256 accu8;
801
771
 
802
- void begin_8 () {
772
+ void begin_8() {
803
773
  accu8 = _mm256_setzero_ps();
804
774
  yi = y;
805
775
  }
806
776
 
807
- void add_8_components (__m256 x) {
808
- __m256 yiv = _mm256_loadu_ps (yi);
777
+ void add_8_components(__m256 x) {
778
+ __m256 yiv = _mm256_loadu_ps(yi);
809
779
  yi += 8;
810
- __m256 tmp = yiv - x;
811
- accu8 += tmp * tmp;
780
+ __m256 tmp = _mm256_sub_ps(yiv, x);
781
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
812
782
  }
813
783
 
814
- void add_8_components_2 (__m256 x, __m256 y) {
815
- __m256 tmp = y - x;
816
- accu8 += tmp * tmp;
784
+ void add_8_components_2(__m256 x, __m256 y) {
785
+ __m256 tmp = _mm256_sub_ps(y, x);
786
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
817
787
  }
818
788
 
819
- float result_8 () {
789
+ float result_8() {
820
790
  __m256 sum = _mm256_hadd_ps(accu8, accu8);
821
791
  __m256 sum2 = _mm256_hadd_ps(sum, sum);
822
792
  // now add the 0th and 4th component
823
- return
824
- _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
825
- _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
793
+ return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
794
+ _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
826
795
  }
827
-
828
796
  };
829
797
 
830
798
  #endif
831
799
 
832
-
833
- template<int SIMDWIDTH>
800
+ template <int SIMDWIDTH>
834
801
  struct SimilarityIP {};
835
802
 
836
-
837
- template<>
803
+ template <>
838
804
  struct SimilarityIP<1> {
839
805
  static constexpr int simdwidth = 1;
840
806
  static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
@@ -842,30 +808,29 @@ struct SimilarityIP<1> {
842
808
 
843
809
  float accu;
844
810
 
845
- explicit SimilarityIP (const float * y):
846
- y (y) {}
811
+ explicit SimilarityIP(const float* y) : y(y) {}
847
812
 
848
- void begin () {
813
+ void begin() {
849
814
  accu = 0;
850
815
  yi = y;
851
816
  }
852
817
 
853
- void add_component (float x) {
854
- accu += *yi++ * x;
818
+ void add_component(float x) {
819
+ accu += *yi++ * x;
855
820
  }
856
821
 
857
- void add_component_2 (float x1, float x2) {
858
- accu += x1 * x2;
822
+ void add_component_2(float x1, float x2) {
823
+ accu += x1 * x2;
859
824
  }
860
825
 
861
- float result () {
826
+ float result() {
862
827
  return accu;
863
828
  }
864
829
  };
865
830
 
866
831
  #ifdef __AVX2__
867
832
 
868
- template<>
833
+ template <>
869
834
  struct SimilarityIP<8> {
870
835
  static constexpr int simdwidth = 8;
871
836
  static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
@@ -874,59 +839,53 @@ struct SimilarityIP<8> {
874
839
 
875
840
  float accu;
876
841
 
877
- explicit SimilarityIP (const float * y):
878
- y (y) {}
842
+ explicit SimilarityIP(const float* y) : y(y) {}
879
843
 
880
844
  __m256 accu8;
881
845
 
882
- void begin_8 () {
846
+ void begin_8() {
883
847
  accu8 = _mm256_setzero_ps();
884
848
  yi = y;
885
849
  }
886
850
 
887
- void add_8_components (__m256 x) {
888
- __m256 yiv = _mm256_loadu_ps (yi);
851
+ void add_8_components(__m256 x) {
852
+ __m256 yiv = _mm256_loadu_ps(yi);
889
853
  yi += 8;
890
- accu8 += yiv * x;
854
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
891
855
  }
892
856
 
893
- void add_8_components_2 (__m256 x1, __m256 x2) {
894
- accu8 += x1 * x2;
857
+ void add_8_components_2(__m256 x1, __m256 x2) {
858
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
895
859
  }
896
860
 
897
- float result_8 () {
861
+ float result_8() {
898
862
  __m256 sum = _mm256_hadd_ps(accu8, accu8);
899
863
  __m256 sum2 = _mm256_hadd_ps(sum, sum);
900
864
  // now add the 0th and 4th component
901
- return
902
- _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
903
- _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
865
+ return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
866
+ _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
904
867
  }
905
868
  };
906
869
  #endif
907
870
 
908
-
909
871
  /*******************************************************************
910
872
  * DistanceComputer: combines a similarity and a quantizer to do
911
873
  * code-to-vector or code-to-code comparisons
912
874
  *******************************************************************/
913
875
 
914
- template<class Quantizer, class Similarity, int SIMDWIDTH>
876
+ template <class Quantizer, class Similarity, int SIMDWIDTH>
915
877
  struct DCTemplate : SQDistanceComputer {};
916
878
 
917
- template<class Quantizer, class Similarity>
918
- struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
919
- {
879
+ template <class Quantizer, class Similarity>
880
+ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
920
881
  using Sim = Similarity;
921
882
 
922
883
  Quantizer quant;
923
884
 
924
- DCTemplate(size_t d, const std::vector<float> &trained):
925
- quant(d, trained)
926
- {}
885
+ DCTemplate(size_t d, const std::vector<float>& trained)
886
+ : quant(d, trained) {}
927
887
 
928
888
  float compute_distance(const float* x, const uint8_t* code) const {
929
-
930
889
  Similarity sim(x);
931
890
  sim.begin();
932
891
  for (size_t i = 0; i < quant.d; i++) {
@@ -937,7 +896,7 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
937
896
  }
938
897
 
939
898
  float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
940
- const {
899
+ const {
941
900
  Similarity sim(nullptr);
942
901
  sim.begin();
943
902
  for (size_t i = 0; i < quant.d; i++) {
@@ -948,41 +907,37 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
948
907
  return sim.result();
949
908
  }
950
909
 
951
- void set_query (const float *x) final {
910
+ void set_query(const float* x) final {
952
911
  q = x;
953
912
  }
954
913
 
955
914
  /// compute distance of vector i to current query
956
- float operator () (idx_t i) final {
957
- return compute_distance (q, codes + i * code_size);
915
+ float operator()(idx_t i) final {
916
+ return query_to_code(codes + i * code_size);
958
917
  }
959
918
 
960
- float symmetric_dis (idx_t i, idx_t j) override {
961
- return compute_code_distance (codes + i * code_size,
962
- codes + j * code_size);
919
+ float symmetric_dis(idx_t i, idx_t j) override {
920
+ return compute_code_distance(
921
+ codes + i * code_size, codes + j * code_size);
963
922
  }
964
923
 
965
- float query_to_code (const uint8_t * code) const {
966
- return compute_distance (q, code);
924
+ float query_to_code(const uint8_t* code) const final {
925
+ return compute_distance(q, code);
967
926
  }
968
-
969
927
  };
970
928
 
971
929
  #ifdef USE_F16C
972
930
 
973
- template<class Quantizer, class Similarity>
974
- struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
975
- {
931
+ template <class Quantizer, class Similarity>
932
+ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
976
933
  using Sim = Similarity;
977
934
 
978
935
  Quantizer quant;
979
936
 
980
- DCTemplate(size_t d, const std::vector<float> &trained):
981
- quant(d, trained)
982
- {}
937
+ DCTemplate(size_t d, const std::vector<float>& trained)
938
+ : quant(d, trained) {}
983
939
 
984
940
  float compute_distance(const float* x, const uint8_t* code) const {
985
-
986
941
  Similarity sim(x);
987
942
  sim.begin_8();
988
943
  for (size_t i = 0; i < quant.d; i += 8) {
@@ -993,7 +948,7 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
993
948
  }
994
949
 
995
950
  float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
996
- const {
951
+ const {
997
952
  Similarity sim(nullptr);
998
953
  sim.begin_8();
999
954
  for (size_t i = 0; i < quant.d; i += 8) {
@@ -1004,49 +959,45 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
1004
959
  return sim.result_8();
1005
960
  }
1006
961
 
1007
- void set_query (const float *x) final {
962
+ void set_query(const float* x) final {
1008
963
  q = x;
1009
964
  }
1010
965
 
1011
966
  /// compute distance of vector i to current query
1012
- float operator () (idx_t i) final {
1013
- return compute_distance (q, codes + i * code_size);
967
+ float operator()(idx_t i) final {
968
+ return query_to_code(codes + i * code_size);
1014
969
  }
1015
970
 
1016
- float symmetric_dis (idx_t i, idx_t j) override {
1017
- return compute_code_distance (codes + i * code_size,
1018
- codes + j * code_size);
971
+ float symmetric_dis(idx_t i, idx_t j) override {
972
+ return compute_code_distance(
973
+ codes + i * code_size, codes + j * code_size);
1019
974
  }
1020
975
 
1021
- float query_to_code (const uint8_t * code) const {
1022
- return compute_distance (q, code);
976
+ float query_to_code(const uint8_t* code) const final {
977
+ return compute_distance(q, code);
1023
978
  }
1024
-
1025
979
  };
1026
980
 
1027
981
  #endif
1028
982
 
1029
-
1030
-
1031
983
  /*******************************************************************
1032
984
  * DistanceComputerByte: computes distances in the integer domain
1033
985
  *******************************************************************/
1034
986
 
1035
- template<class Similarity, int SIMDWIDTH>
987
+ template <class Similarity, int SIMDWIDTH>
1036
988
  struct DistanceComputerByte : SQDistanceComputer {};
1037
989
 
1038
- template<class Similarity>
990
+ template <class Similarity>
1039
991
  struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1040
992
  using Sim = Similarity;
1041
993
 
1042
994
  int d;
1043
995
  std::vector<uint8_t> tmp;
1044
996
 
1045
- DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1046
- }
997
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1047
998
 
1048
999
  int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1049
- const {
1000
+ const {
1050
1001
  int accu = 0;
1051
1002
  for (int i = 0; i < d; i++) {
1052
1003
  if (Sim::metric_type == METRIC_INNER_PRODUCT) {
@@ -1059,7 +1010,7 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1059
1010
  return accu;
1060
1011
  }
1061
1012
 
1062
- void set_query (const float *x) final {
1013
+ void set_query(const float* x) final {
1063
1014
  for (int i = 0; i < d; i++) {
1064
1015
  tmp[i] = int(x[i]);
1065
1016
  }
@@ -1071,44 +1022,41 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1071
1022
  }
1072
1023
 
1073
1024
  /// compute distance of vector i to current query
1074
- float operator () (idx_t i) final {
1075
- return compute_distance (q, codes + i * code_size);
1025
+ float operator()(idx_t i) final {
1026
+ return query_to_code(codes + i * code_size);
1076
1027
  }
1077
1028
 
1078
- float symmetric_dis (idx_t i, idx_t j) override {
1079
- return compute_code_distance (codes + i * code_size,
1080
- codes + j * code_size);
1029
+ float symmetric_dis(idx_t i, idx_t j) override {
1030
+ return compute_code_distance(
1031
+ codes + i * code_size, codes + j * code_size);
1081
1032
  }
1082
1033
 
1083
- float query_to_code (const uint8_t * code) const {
1084
- return compute_code_distance (tmp.data(), code);
1034
+ float query_to_code(const uint8_t* code) const final {
1035
+ return compute_code_distance(tmp.data(), code);
1085
1036
  }
1086
-
1087
1037
  };
1088
1038
 
1089
1039
  #ifdef __AVX2__
1090
1040
 
1091
-
1092
- template<class Similarity>
1041
+ template <class Similarity>
1093
1042
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1094
1043
  using Sim = Similarity;
1095
1044
 
1096
1045
  int d;
1097
1046
  std::vector<uint8_t> tmp;
1098
1047
 
1099
- DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1100
- }
1048
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1101
1049
 
1102
1050
  int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1103
- const {
1051
+ const {
1104
1052
  // __m256i accu = _mm256_setzero_ps ();
1105
- __m256i accu = _mm256_setzero_si256 ();
1053
+ __m256i accu = _mm256_setzero_si256();
1106
1054
  for (int i = 0; i < d; i += 16) {
1107
1055
  // load 16 bytes, convert to 16 uint16_t
1108
- __m256i c1 = _mm256_cvtepu8_epi16
1109
- (_mm_loadu_si128((__m128i*)(code1 + i)));
1110
- __m256i c2 = _mm256_cvtepu8_epi16
1111
- (_mm_loadu_si128((__m128i*)(code2 + i)));
1056
+ __m256i c1 = _mm256_cvtepu8_epi16(
1057
+ _mm_loadu_si128((__m128i*)(code1 + i)));
1058
+ __m256i c2 = _mm256_cvtepu8_epi16(
1059
+ _mm_loadu_si128((__m128i*)(code2 + i)));
1112
1060
  __m256i prod32;
1113
1061
  if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1114
1062
  prod32 = _mm256_madd_epi16(c1, c2);
@@ -1116,17 +1064,16 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1116
1064
  __m256i diff = _mm256_sub_epi16(c1, c2);
1117
1065
  prod32 = _mm256_madd_epi16(diff, diff);
1118
1066
  }
1119
- accu = _mm256_add_epi32 (accu, prod32);
1120
-
1067
+ accu = _mm256_add_epi32(accu, prod32);
1121
1068
  }
1122
1069
  __m128i sum = _mm256_extractf128_si256(accu, 0);
1123
- sum = _mm_add_epi32 (sum, _mm256_extractf128_si256(accu, 1));
1124
- sum = _mm_hadd_epi32 (sum, sum);
1125
- sum = _mm_hadd_epi32 (sum, sum);
1126
- return _mm_cvtsi128_si32 (sum);
1070
+ sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1));
1071
+ sum = _mm_hadd_epi32(sum, sum);
1072
+ sum = _mm_hadd_epi32(sum, sum);
1073
+ return _mm_cvtsi128_si32(sum);
1127
1074
  }
1128
1075
 
1129
- void set_query (const float *x) final {
1076
+ void set_query(const float* x) final {
1130
1077
  /*
1131
1078
  for (int i = 0; i < d; i += 8) {
1132
1079
  __m256 xi = _mm256_loadu_ps (x + i);
@@ -1143,20 +1090,18 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1143
1090
  }
1144
1091
 
1145
1092
  /// compute distance of vector i to current query
1146
- float operator () (idx_t i) final {
1147
- return compute_distance (q, codes + i * code_size);
1093
+ float operator()(idx_t i) final {
1094
+ return query_to_code(codes + i * code_size);
1148
1095
  }
1149
1096
 
1150
- float symmetric_dis (idx_t i, idx_t j) override {
1151
- return compute_code_distance (codes + i * code_size,
1152
- codes + j * code_size);
1097
+ float symmetric_dis(idx_t i, idx_t j) override {
1098
+ return compute_code_distance(
1099
+ codes + i * code_size, codes + j * code_size);
1153
1100
  }
1154
1101
 
1155
- float query_to_code (const uint8_t * code) const {
1156
- return compute_code_distance (tmp.data(), code);
1102
+ float query_to_code(const uint8_t* code) const final {
1103
+ return compute_code_distance(tmp.data(), code);
1157
1104
  }
1158
-
1159
-
1160
1105
  };
1161
1106
 
1162
1107
  #endif
@@ -1166,215 +1111,218 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
  * specialization
  *******************************************************************/
 
-
-template<class Sim>
-SQDistanceComputer *select_distance_computer (
-        QuantizerType qtype,
-        size_t d, const std::vector<float> & trained)
-{
+template <class Sim>
+SQDistanceComputer* select_distance_computer(
+        QuantizerType qtype,
+        size_t d,
+        const std::vector<float>& trained) {
     constexpr int SIMDWIDTH = Sim::simdwidth;
-    switch(qtype) {
-    case ScalarQuantizer::QT_8bit_uniform:
-        return new DCTemplate<QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
-                              Sim, SIMDWIDTH>(d, trained);
-
-    case ScalarQuantizer::QT_4bit_uniform:
-        return new DCTemplate<QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
-                              Sim, SIMDWIDTH>(d, trained);
-
-    case ScalarQuantizer::QT_8bit:
-        return new DCTemplate<QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
-                              Sim, SIMDWIDTH>(d, trained);
-
-    case ScalarQuantizer::QT_6bit:
-        return new DCTemplate<QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
-                              Sim, SIMDWIDTH>(d, trained);
-
-    case ScalarQuantizer::QT_4bit:
-        return new DCTemplate<QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
-                              Sim, SIMDWIDTH>(d, trained);
-
-    case ScalarQuantizer::QT_fp16:
-        return new DCTemplate
-            <QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
-
-    case ScalarQuantizer::QT_8bit_direct:
-        if (d % 16 == 0) {
-            return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
-        } else {
-            return new DCTemplate
-                <Quantizer8bitDirect<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
-        }
+    switch (qtype) {
+        case ScalarQuantizer::QT_8bit_uniform:
+            return new DCTemplate<
+                    QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
+
+        case ScalarQuantizer::QT_4bit_uniform:
+            return new DCTemplate<
+                    QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
+
+        case ScalarQuantizer::QT_8bit:
+            return new DCTemplate<
+                    QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
+
+        case ScalarQuantizer::QT_6bit:
+            return new DCTemplate<
+                    QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
+
+        case ScalarQuantizer::QT_4bit:
+            return new DCTemplate<
+                    QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
+                    Sim,
+                    SIMDWIDTH>(d, trained);
+
+        case ScalarQuantizer::QT_fp16:
+            return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
+                    d, trained);
+
+        case ScalarQuantizer::QT_8bit_direct:
+            if (d % 16 == 0) {
+                return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
+            } else {
+                return new DCTemplate<
+                        Quantizer8bitDirect<SIMDWIDTH>,
+                        Sim,
+                        SIMDWIDTH>(d, trained);
+            }
     }
-    FAISS_THROW_MSG ("unknown qtype");
+    FAISS_THROW_MSG("unknown qtype");
     return nullptr;
 }
 
-
-
 } // anonymous namespace
 
-
-
 /*******************************************************************
  * ScalarQuantizer implementation
 ********************************************************************/
 
-
-
-ScalarQuantizer::ScalarQuantizer
-          (size_t d, QuantizerType qtype):
-          qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d(d)
-{
-    set_derived_sizes();
+ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
+        : qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
+    set_derived_sizes();
 }
 
-ScalarQuantizer::ScalarQuantizer ():
-          qtype(QT_8bit),
-          rangestat(RS_minmax), rangestat_arg(0), d(0), bits(0), code_size(0)
-{}
+ScalarQuantizer::ScalarQuantizer()
+        : qtype(QT_8bit),
+          rangestat(RS_minmax),
+          rangestat_arg(0),
+          d(0),
+          bits(0),
+          code_size(0) {}
 
-void ScalarQuantizer::set_derived_sizes ()
-{
+void ScalarQuantizer::set_derived_sizes() {
     switch (qtype) {
-    case QT_8bit:
-    case QT_8bit_uniform:
-    case QT_8bit_direct:
-        code_size = d;
-        bits = 8;
-        break;
-    case QT_4bit:
-    case QT_4bit_uniform:
-        code_size = (d + 1) / 2;
-        bits = 4;
-        break;
-    case QT_6bit:
-        code_size = (d * 6 + 7) / 8;
-        bits = 6;
-        break;
-    case QT_fp16:
-        code_size = d * 2;
-        bits = 16;
-        break;
+        case QT_8bit:
+        case QT_8bit_uniform:
+        case QT_8bit_direct:
+            code_size = d;
+            bits = 8;
+            break;
+        case QT_4bit:
+        case QT_4bit_uniform:
+            code_size = (d + 1) / 2;
+            bits = 4;
+            break;
+        case QT_6bit:
+            code_size = (d * 6 + 7) / 8;
+            bits = 6;
+            break;
+        case QT_fp16:
+            code_size = d * 2;
+            bits = 16;
+            break;
    }
 }
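
Note: set_derived_sizes() fixes the per-vector code size in bytes. A hedged arithmetic check for d = 100 (the values follow directly from the switch above; the snippet itself is not part of the diff):

#include <cstddef>
#include <cstdio>

int main() {
    std::size_t d = 100;
    std::printf("QT_8bit : %zu bytes\n", d);               // 100
    std::printf("QT_4bit : %zu bytes\n", (d + 1) / 2);     // 50
    std::printf("QT_6bit : %zu bytes\n", (d * 6 + 7) / 8); // 75
    std::printf("QT_fp16 : %zu bytes\n", d * 2);           // 200
}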
 
-void ScalarQuantizer::train (size_t n, const float *x)
-{
-    int bit_per_dim =
-        qtype == QT_4bit_uniform ? 4 :
-        qtype == QT_4bit ? 4 :
-        qtype == QT_6bit ? 6 :
-        qtype == QT_8bit_uniform ? 8 :
-        qtype == QT_8bit ? 8 : -1;
+void ScalarQuantizer::train(size_t n, const float* x) {
+    int bit_per_dim = qtype == QT_4bit_uniform ? 4
+            : qtype == QT_4bit                 ? 4
+            : qtype == QT_6bit                 ? 6
+            : qtype == QT_8bit_uniform         ? 8
+            : qtype == QT_8bit                 ? 8
+                                               : -1;
 
     switch (qtype) {
-    case QT_4bit_uniform: case QT_8bit_uniform:
-        train_Uniform (rangestat, rangestat_arg,
-                       n * d, 1 << bit_per_dim, x, trained);
-        break;
-    case QT_4bit: case QT_8bit: case QT_6bit:
-        train_NonUniform (rangestat, rangestat_arg,
-                          n, d, 1 << bit_per_dim, x, trained);
-        break;
-    case QT_fp16:
-    case QT_8bit_direct:
-        // no training necessary
-        break;
+        case QT_4bit_uniform:
+        case QT_8bit_uniform:
+            train_Uniform(
+                    rangestat,
+                    rangestat_arg,
+                    n * d,
+                    1 << bit_per_dim,
+                    x,
+                    trained);
+            break;
+        case QT_4bit:
+        case QT_8bit:
+        case QT_6bit:
+            train_NonUniform(
+                    rangestat,
+                    rangestat_arg,
+                    n,
+                    d,
+                    1 << bit_per_dim,
+                    x,
+                    trained);
+            break;
+        case QT_fp16:
+        case QT_8bit_direct:
+            // no training necessary
+            break;
    }
 }
 
-void ScalarQuantizer::train_residual(size_t n,
-                                     const float *x,
-                                     Index *quantizer,
-                                     bool by_residual,
-                                     bool verbose)
-{
-    const float * x_in = x;
+void ScalarQuantizer::train_residual(
+        size_t n,
+        const float* x,
+        Index* quantizer,
+        bool by_residual,
+        bool verbose) {
+    const float* x_in = x;
 
     // 100k points more than enough
-    x = fvecs_maybe_subsample (
-        d, (size_t*)&n, 100000,
-        x, verbose, 1234);
+    x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
 
-    ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
+    ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
 
     if (by_residual) {
         std::vector<Index::idx_t> idx(n);
-        quantizer->assign (n, x, idx.data());
+        quantizer->assign(n, x, idx.data());
 
         std::vector<float> residuals(n * d);
-        quantizer->compute_residual_n (n, x, residuals.data(), idx.data());
+        quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
 
-        train (n, residuals.data());
+        train(n, residuals.data());
     } else {
-        train (n, x);
+        train(n, x);
    }
 }
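
Note: train_residual() assigns each training vector to its nearest coarse centroid and trains the scalar quantizer on the residuals, as the assign / compute_residual_n / train sequence above shows. A minimal sketch of what one residual is (the helper is illustrative, not a FAISS API):

#include <cstddef>
#include <vector>

// r[i] = x[i] - centroid[i]; the scalar quantizer is then trained on r.
std::vector<float> residual(
        const std::vector<float>& x,          // one input vector, size d
        const std::vector<float>& centroid) { // its assigned coarse centroid
    std::vector<float> r(x.size());
    for (std::size_t i = 0; i < x.size(); i++) {
        r[i] = x[i] - centroid[i];
    }
    return r;
}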
 
-
-ScalarQuantizer::Quantizer *ScalarQuantizer::select_quantizer () const
-{
+ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
 #ifdef USE_F16C
     if (d % 8 == 0) {
-        return select_quantizer_1<8> (qtype, d, trained);
+        return select_quantizer_1<8>(qtype, d, trained);
     } else
 #endif
     {
-        return select_quantizer_1<1> (qtype, d, trained);
+        return select_quantizer_1<1>(qtype, d, trained);
     }
 }
 
+void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
+        const {
+    std::unique_ptr<Quantizer> squant(select_quantizer());
 
-void ScalarQuantizer::compute_codes (const float * x,
-                                     uint8_t * codes,
-                                     size_t n) const
-{
-    std::unique_ptr<Quantizer> squant(select_quantizer ());
-
-    memset (codes, 0, code_size * n);
+    memset(codes, 0, code_size * n);
 #pragma omp parallel for
     for (int64_t i = 0; i < n; i++)
-        squant->encode_vector (x + i * d, codes + i * code_size);
+        squant->encode_vector(x + i * d, codes + i * code_size);
 }
 
-void ScalarQuantizer::decode (const uint8_t *codes, float *x, size_t n) const
-{
-    std::unique_ptr<Quantizer> squant(select_quantizer ());
+void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
+    std::unique_ptr<Quantizer> squant(select_quantizer());
 
 #pragma omp parallel for
     for (int64_t i = 0; i < n; i++)
-        squant->decode_vector (codes + i * code_size, x + i * d);
+        squant->decode_vector(codes + i * code_size, x + i * d);
 }
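
Note: compute_codes() and decode() are the public encode/decode entry points. A hedged round-trip sketch using only members visible in this diff (the header path is an assumption):

#include <faiss/impl/ScalarQuantizer.h> // assumed vendored header path

#include <cstdint>
#include <vector>

void roundtrip_example(const float* x, size_t n, size_t d) {
    faiss::ScalarQuantizer sq(d, faiss::ScalarQuantizer::QT_8bit);
    sq.train(n, x); // learn per-dimension ranges (RS_minmax by default)

    std::vector<uint8_t> codes(n * sq.code_size);
    sq.compute_codes(x, codes.data(), n); // encode n vectors

    std::vector<float> decoded(n * d);
    sq.decode(codes.data(), decoded.data(), n); // lossy reconstruction
}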
 
-
-SQDistanceComputer *
-ScalarQuantizer::get_distance_computer (MetricType metric) const
-{
+SQDistanceComputer* ScalarQuantizer::get_distance_computer(
+        MetricType metric) const {
     FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
 #ifdef USE_F16C
     if (d % 8 == 0) {
         if (metric == METRIC_L2) {
-            return select_distance_computer<SimilarityL2<8> >
-                (qtype, d, trained);
+            return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
         } else {
-            return select_distance_computer<SimilarityIP<8> >
-                (qtype, d, trained);
+            return select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
         }
     } else
 #endif
     {
         if (metric == METRIC_L2) {
-            return select_distance_computer<SimilarityL2<1> >
-                (qtype, d, trained);
+            return select_distance_computer<SimilarityL2<1>>(qtype, d, trained);
         } else {
-            return select_distance_computer<SimilarityIP<1> >
-                (qtype, d, trained);
+            return select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
         }
     }
 }
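
Note: the SQDistanceComputer returned here follows the set_query / query_to_code protocol implemented by the templates earlier in this file. A hedged usage sketch (assumes `sq` is trained, the FAISS headers are included, and `code` points at one encoded vector; not part of the diff):

#include <cstdint>
#include <memory>

float distance_to_stored_code(
        const faiss::ScalarQuantizer& sq,
        const float* query,
        const uint8_t* code) {
    std::unique_ptr<faiss::SQDistanceComputer> dc(
            sq.get_distance_computer(faiss::METRIC_L2));
    dc->set_query(query);           // fix the query side once
    return dc->query_to_code(code); // distance from query to one code
}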
 
-
 /*******************************************************************
  * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
  *
@@ -1384,54 +1332,57 @@ ScalarQuantizer::get_distance_computer (MetricType metric) const
 
 namespace {
 
-
-template<class DCClass>
-struct IVFSQScannerIP: InvertedListScanner {
+template <class DCClass>
+struct IVFSQScannerIP : InvertedListScanner {
     DCClass dc;
     bool store_pairs, by_residual;
 
     size_t code_size;
 
-    idx_t list_no;  /// current list (set to 0 for Flat index
-    float accu0;    /// added to all distances
-
-    IVFSQScannerIP(int d, const std::vector<float> & trained,
-                   size_t code_size, bool store_pairs,
-                   bool by_residual):
-        dc(d, trained), store_pairs(store_pairs),
-        by_residual(by_residual),
-        code_size(code_size), list_no(0), accu0(0)
-    {}
+    idx_t list_no; /// current list (set to 0 for Flat index
+    float accu0;   /// added to all distances
 
+    IVFSQScannerIP(
+            int d,
+            const std::vector<float>& trained,
+            size_t code_size,
+            bool store_pairs,
+            bool by_residual)
+            : dc(d, trained),
+              store_pairs(store_pairs),
+              by_residual(by_residual),
+              code_size(code_size),
+              list_no(0),
+              accu0(0) {}
 
-    void set_query (const float *query) override {
-        dc.set_query (query);
+    void set_query(const float* query) override {
+        dc.set_query(query);
     }
 
-    void set_list (idx_t list_no, float coarse_dis) override {
+    void set_list(idx_t list_no, float coarse_dis) override {
         this->list_no = list_no;
         accu0 = by_residual ? coarse_dis : 0;
     }
 
-    float distance_to_code (const uint8_t *code) const final {
-        return accu0 + dc.query_to_code (code);
+    float distance_to_code(const uint8_t* code) const final {
+        return accu0 + dc.query_to_code(code);
     }
 
-    size_t scan_codes (size_t list_size,
-                       const uint8_t *codes,
-                       const idx_t *ids,
-                       float *simi, idx_t *idxi,
-                       size_t k) const override
-    {
+    size_t scan_codes(
+            size_t list_size,
+            const uint8_t* codes,
+            const idx_t* ids,
+            float* simi,
+            idx_t* idxi,
+            size_t k) const override {
         size_t nup = 0;
 
         for (size_t j = 0; j < list_size; j++) {
+            float accu = accu0 + dc.query_to_code(codes);
 
-            float accu = accu0 + dc.query_to_code (codes);
-
-            if (accu > simi [0]) {
+            if (accu > simi[0]) {
                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
-                minheap_replace_top (k, simi, idxi, accu, id);
+                minheap_replace_top(k, simi, idxi, accu, id);
                 nup++;
             }
             codes += code_size;
@@ -1439,86 +1390,87 @@ struct IVFSQScannerIP: InvertedListScanner {
         return nup;
     }
 
-    void scan_codes_range (size_t list_size,
-                           const uint8_t *codes,
-                           const idx_t *ids,
-                           float radius,
-                           RangeQueryResult & res) const override
-    {
+    void scan_codes_range(
+            size_t list_size,
+            const uint8_t* codes,
+            const idx_t* ids,
+            float radius,
+            RangeQueryResult& res) const override {
         for (size_t j = 0; j < list_size; j++) {
-            float accu = accu0 + dc.query_to_code (codes);
+            float accu = accu0 + dc.query_to_code(codes);
             if (accu > radius) {
                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
-                res.add (accu, id);
+                res.add(accu, id);
             }
             codes += code_size;
         }
     }
-
-
 };
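
Note: for inner product the scanner keeps the k largest scores in a min-heap, so simi[0] is always the weakest result retained; that is why a candidate only enters via `accu > simi[0]` and minheap_replace_top. A hedged sketch of the same invariant using the standard library instead of the FAISS heap helpers:

#include <cstddef>
#include <functional>
#include <queue>
#include <vector>

// Min-heap of the k best (largest) scores; heap.top() plays simi[0].
void keep_top_k(
        std::priority_queue<float, std::vector<float>, std::greater<float>>&
                heap,
        float score,
        std::size_t k) {
    if (heap.size() < k) {
        heap.push(score);
    } else if (score > heap.top()) {
        heap.pop();       // drop the current weakest score
        heap.push(score); // keep the better one
    }
}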
 
-
-template<class DCClass>
-struct IVFSQScannerL2: InvertedListScanner {
-
+template <class DCClass>
+struct IVFSQScannerL2 : InvertedListScanner {
     DCClass dc;
 
     bool store_pairs, by_residual;
     size_t code_size;
-    const Index *quantizer;
-    idx_t list_no;    /// current inverted list
-    const float *x;   /// current query
+    const Index* quantizer;
+    idx_t list_no;  /// current inverted list
+    const float* x; /// current query
 
     std::vector<float> tmp;
 
-    IVFSQScannerL2(int d, const std::vector<float> & trained,
-                   size_t code_size, const Index *quantizer,
-                   bool store_pairs, bool by_residual):
-        dc(d, trained), store_pairs(store_pairs), by_residual(by_residual),
-        code_size(code_size), quantizer(quantizer),
-        list_no (0), x (nullptr), tmp (d)
-    {
-    }
-
-
-    void set_query (const float *query) override {
+    IVFSQScannerL2(
+            int d,
+            const std::vector<float>& trained,
+            size_t code_size,
+            const Index* quantizer,
+            bool store_pairs,
+            bool by_residual)
+            : dc(d, trained),
+              store_pairs(store_pairs),
+              by_residual(by_residual),
+              code_size(code_size),
+              quantizer(quantizer),
+              list_no(0),
+              x(nullptr),
+              tmp(d) {}
+
+    void set_query(const float* query) override {
         x = query;
         if (!quantizer) {
-            dc.set_query (query);
+            dc.set_query(query);
         }
     }
 
-
-    void set_list (idx_t list_no, float /*coarse_dis*/) override {
+    void set_list(idx_t list_no, float /*coarse_dis*/) override {
        if (by_residual) {
            this->list_no = list_no;
            // shift of x_in wrt centroid
-            quantizer->compute_residual (x, tmp.data(), list_no);
-            dc.set_query (tmp.data ());
+            quantizer->compute_residual(x, tmp.data(), list_no);
+            dc.set_query(tmp.data());
        } else {
-            dc.set_query (x);
+            dc.set_query(x);
        }
    }
 
-    float distance_to_code (const uint8_t *code) const final {
-        return dc.query_to_code (code);
+    float distance_to_code(const uint8_t* code) const final {
+        return dc.query_to_code(code);
     }
 
-    size_t scan_codes (size_t list_size,
-                       const uint8_t *codes,
-                       const idx_t *ids,
-                       float *simi, idx_t *idxi,
-                       size_t k) const override
-    {
+    size_t scan_codes(
+            size_t list_size,
+            const uint8_t* codes,
+            const idx_t* ids,
+            float* simi,
+            idx_t* idxi,
+            size_t k) const override {
         size_t nup = 0;
         for (size_t j = 0; j < list_size; j++) {
+            float dis = dc.query_to_code(codes);
 
-            float dis = dc.query_to_code (codes);
-
-            if (dis < simi [0]) {
+            if (dis < simi[0]) {
                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
-                maxheap_replace_top (k, simi, idxi, dis, id);
+                maxheap_replace_top(k, simi, idxi, dis, id);
                 nup++;
             }
             codes += code_size;
@@ -1526,137 +1478,132 @@ struct IVFSQScannerL2: InvertedListScanner {
         return nup;
     }
 
-    void scan_codes_range (size_t list_size,
-                           const uint8_t *codes,
-                           const idx_t *ids,
-                           float radius,
-                           RangeQueryResult & res) const override
-    {
+    void scan_codes_range(
+            size_t list_size,
+            const uint8_t* codes,
+            const idx_t* ids,
+            float radius,
+            RangeQueryResult& res) const override {
         for (size_t j = 0; j < list_size; j++) {
-            float dis = dc.query_to_code (codes);
+            float dis = dc.query_to_code(codes);
             if (dis < radius) {
                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
-                res.add (dis, id);
+                res.add(dis, id);
             }
             codes += code_size;
         }
     }
-
-
 };
 
-template<class DCClass>
-InvertedListScanner* sel2_InvertedListScanner
-      (const ScalarQuantizer *sq,
-       const Index *quantizer, bool store_pairs, bool r)
-{
+template <class DCClass>
+InvertedListScanner* sel2_InvertedListScanner(
+        const ScalarQuantizer* sq,
+        const Index* quantizer,
+        bool store_pairs,
+        bool r) {
    if (DCClass::Sim::metric_type == METRIC_L2) {
-        return new IVFSQScannerL2<DCClass>(sq->d, sq->trained, sq->code_size,
-                                           quantizer, store_pairs, r);
+        return new IVFSQScannerL2<DCClass>(
+                sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
    } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
-        return new IVFSQScannerIP<DCClass>(sq->d, sq->trained, sq->code_size,
-                                           store_pairs, r);
+        return new IVFSQScannerIP<DCClass>(
+                sq->d, sq->trained, sq->code_size, store_pairs, r);
    } else {
        FAISS_THROW_MSG("unsupported metric type");
    }
 }
 
-template<class Similarity, class Codec, bool uniform>
-InvertedListScanner* sel12_InvertedListScanner
-        (const ScalarQuantizer *sq,
-         const Index *quantizer, bool store_pairs, bool r)
-{
+template <class Similarity, class Codec, bool uniform>
+InvertedListScanner* sel12_InvertedListScanner(
+        const ScalarQuantizer* sq,
+        const Index* quantizer,
+        bool store_pairs,
+        bool r) {
    constexpr int SIMDWIDTH = Similarity::simdwidth;
    using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
    using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
-    return sel2_InvertedListScanner<DCClass> (sq, quantizer, store_pairs, r);
+    return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
 }
 
-
-
-template<class Similarity>
-InvertedListScanner* sel1_InvertedListScanner
-        (const ScalarQuantizer *sq, const Index *quantizer,
-         bool store_pairs, bool r)
-{
+template <class Similarity>
+InvertedListScanner* sel1_InvertedListScanner(
+        const ScalarQuantizer* sq,
+        const Index* quantizer,
+        bool store_pairs,
+        bool r) {
    constexpr int SIMDWIDTH = Similarity::simdwidth;
-    switch(sq->qtype) {
-    case ScalarQuantizer::QT_8bit_uniform:
-        return sel12_InvertedListScanner
-            <Similarity, Codec8bit, true>(sq, quantizer, store_pairs, r);
-    case ScalarQuantizer::QT_4bit_uniform:
-        return sel12_InvertedListScanner
-            <Similarity, Codec4bit, true>(sq, quantizer, store_pairs, r);
-    case ScalarQuantizer::QT_8bit:
-        return sel12_InvertedListScanner
-            <Similarity, Codec8bit, false>(sq, quantizer, store_pairs, r);
-    case ScalarQuantizer::QT_4bit:
-        return sel12_InvertedListScanner
-            <Similarity, Codec4bit, false>(sq, quantizer, store_pairs, r);
-    case ScalarQuantizer::QT_6bit:
-        return sel12_InvertedListScanner
-            <Similarity, Codec6bit, false>(sq, quantizer, store_pairs, r);
-    case ScalarQuantizer::QT_fp16:
-        return sel2_InvertedListScanner
-            <DCTemplate<QuantizerFP16<SIMDWIDTH>, Similarity, SIMDWIDTH> >
-            (sq, quantizer, store_pairs, r);
-    case ScalarQuantizer::QT_8bit_direct:
-        if (sq->d % 16 == 0) {
-            return sel2_InvertedListScanner
-                <DistanceComputerByte<Similarity, SIMDWIDTH> >
-                (sq, quantizer, store_pairs, r);
-        } else {
-            return sel2_InvertedListScanner
-                <DCTemplate<Quantizer8bitDirect<SIMDWIDTH>,
-                            Similarity, SIMDWIDTH> >
-                (sq, quantizer, store_pairs, r);
-        }
-
+    switch (sq->qtype) {
+        case ScalarQuantizer::QT_8bit_uniform:
+            return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
+                    sq, quantizer, store_pairs, r);
+        case ScalarQuantizer::QT_4bit_uniform:
+            return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
+                    sq, quantizer, store_pairs, r);
+        case ScalarQuantizer::QT_8bit:
+            return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
+                    sq, quantizer, store_pairs, r);
+        case ScalarQuantizer::QT_4bit:
+            return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
+                    sq, quantizer, store_pairs, r);
+        case ScalarQuantizer::QT_6bit:
+            return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
+                    sq, quantizer, store_pairs, r);
+        case ScalarQuantizer::QT_fp16:
+            return sel2_InvertedListScanner<DCTemplate<
+                    QuantizerFP16<SIMDWIDTH>,
+                    Similarity,
+                    SIMDWIDTH>>(sq, quantizer, store_pairs, r);
+        case ScalarQuantizer::QT_8bit_direct:
+            if (sq->d % 16 == 0) {
+                return sel2_InvertedListScanner<
+                        DistanceComputerByte<Similarity, SIMDWIDTH>>(
+                        sq, quantizer, store_pairs, r);
+            } else {
+                return sel2_InvertedListScanner<DCTemplate<
+                        Quantizer8bitDirect<SIMDWIDTH>,
+                        Similarity,
+                        SIMDWIDTH>>(sq, quantizer, store_pairs, r);
+            }
    }
 
-    FAISS_THROW_MSG ("unknown qtype");
+    FAISS_THROW_MSG("unknown qtype");
    return nullptr;
 }
 
-template<int SIMDWIDTH>
-InvertedListScanner* sel0_InvertedListScanner
-        (MetricType mt, const ScalarQuantizer *sq,
-         const Index *quantizer, bool store_pairs, bool by_residual)
-{
+template <int SIMDWIDTH>
+InvertedListScanner* sel0_InvertedListScanner(
+        MetricType mt,
+        const ScalarQuantizer* sq,
+        const Index* quantizer,
+        bool store_pairs,
+        bool by_residual) {
    if (mt == METRIC_L2) {
-        return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH> >
-            (sq, quantizer, store_pairs, by_residual);
+        return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
+                sq, quantizer, store_pairs, by_residual);
    } else if (mt == METRIC_INNER_PRODUCT) {
-        return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH> >
-            (sq, quantizer, store_pairs, by_residual);
+        return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
+                sq, quantizer, store_pairs, by_residual);
    } else {
        FAISS_THROW_MSG("unsupported metric type");
    }
 }
 
-
-
 } // anonymous namespace
 
-
-InvertedListScanner* ScalarQuantizer::select_InvertedListScanner
-    (MetricType mt, const Index *quantizer,
-     bool store_pairs, bool by_residual) const
-{
+InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
+        MetricType mt,
+        const Index* quantizer,
+        bool store_pairs,
+        bool by_residual) const {
 #ifdef USE_F16C
    if (d % 8 == 0) {
-        return sel0_InvertedListScanner<8>
-            (mt, this, quantizer, store_pairs, by_residual);
+        return sel0_InvertedListScanner<8>(
+                mt, this, quantizer, store_pairs, by_residual);
    } else
 #endif
    {
-        return sel0_InvertedListScanner<1>
-            (mt, this, quantizer, store_pairs, by_residual);
+        return sel0_InvertedListScanner<1>(
+                mt, this, quantizer, store_pairs, by_residual);
    }
 }
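
Note: select_InvertedListScanner() is the hook the IVF index uses to scan one inverted list at a time. A hedged driver sketch, using only the scanner interface visible in this diff (the surrounding setup -- coarse quantizer, codes, ids, pre-initialized result heaps -- is assumed):

#include <cstdint>
#include <memory>

size_t scan_one_list(
        const faiss::ScalarQuantizer& sq,
        const faiss::Index* coarse, // coarse quantizer of the IVF index
        const float* query,
        faiss::Index::idx_t list_no,
        float coarse_dis,
        size_t list_size,
        const uint8_t* codes,
        const faiss::Index::idx_t* ids,
        float* simi,               // k-entry distance heap, pre-initialized
        faiss::Index::idx_t* idxi, // k-entry id heap
        size_t k) {
    std::unique_ptr<faiss::InvertedListScanner> scanner(
            sq.select_InvertedListScanner(
                    faiss::METRIC_L2,
                    coarse,
                    /*store_pairs=*/false,
                    /*by_residual=*/true));
    scanner->set_query(query);
    scanner->set_list(list_no, coarse_dis);
    return scanner->scan_codes(list_size, codes, ids, simi, idxi, k);
}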
 
-
-
-
-
 } // namespace faiss