faiss 0.2.0 → 0.2.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +16 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +7 -7
  5. data/ext/faiss/extconf.rb +6 -3
  6. data/ext/faiss/numo.hpp +4 -4
  7. data/ext/faiss/utils.cpp +1 -1
  8. data/ext/faiss/utils.h +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +292 -291
  11. data/vendor/faiss/faiss/AutoTune.h +55 -56
  12. data/vendor/faiss/faiss/Clustering.cpp +365 -194
  13. data/vendor/faiss/faiss/Clustering.h +102 -35
  14. data/vendor/faiss/faiss/IVFlib.cpp +171 -195
  15. data/vendor/faiss/faiss/IVFlib.h +48 -51
  16. data/vendor/faiss/faiss/Index.cpp +85 -103
  17. data/vendor/faiss/faiss/Index.h +54 -48
  18. data/vendor/faiss/faiss/Index2Layer.cpp +126 -224
  19. data/vendor/faiss/faiss/Index2Layer.h +22 -36
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +195 -0
  22. data/vendor/faiss/faiss/IndexBinary.cpp +45 -37
  23. data/vendor/faiss/faiss/IndexBinary.h +140 -132
  24. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +73 -53
  25. data/vendor/faiss/faiss/IndexBinaryFlat.h +29 -24
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +46 -43
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +16 -15
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +215 -232
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.h +25 -24
  30. data/vendor/faiss/faiss/IndexBinaryHash.cpp +182 -177
  31. data/vendor/faiss/faiss/IndexBinaryHash.h +41 -34
  32. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +489 -461
  33. data/vendor/faiss/faiss/IndexBinaryIVF.h +97 -68
  34. data/vendor/faiss/faiss/IndexFlat.cpp +115 -176
  35. data/vendor/faiss/faiss/IndexFlat.h +42 -59
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +372 -348
  39. data/vendor/faiss/faiss/IndexHNSW.h +57 -41
  40. data/vendor/faiss/faiss/IndexIVF.cpp +545 -453
  41. data/vendor/faiss/faiss/IndexIVF.h +169 -118
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
  43. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +247 -252
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +48 -51
  46. data/vendor/faiss/faiss/IndexIVFPQ.cpp +459 -517
  47. data/vendor/faiss/faiss/IndexIVFPQ.h +75 -67
  48. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +406 -372
  49. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +82 -57
  50. data/vendor/faiss/faiss/IndexIVFPQR.cpp +104 -102
  51. data/vendor/faiss/faiss/IndexIVFPQR.h +33 -28
  52. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +163 -150
  53. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +38 -25
  54. data/vendor/faiss/faiss/IndexLSH.cpp +66 -113
  55. data/vendor/faiss/faiss/IndexLSH.h +20 -38
  56. data/vendor/faiss/faiss/IndexLattice.cpp +42 -56
  57. data/vendor/faiss/faiss/IndexLattice.h +11 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +229 -0
  59. data/vendor/faiss/faiss/IndexNNDescent.h +72 -0
  60. data/vendor/faiss/faiss/IndexNSG.cpp +301 -0
  61. data/vendor/faiss/faiss/IndexNSG.h +85 -0
  62. data/vendor/faiss/faiss/IndexPQ.cpp +387 -495
  63. data/vendor/faiss/faiss/IndexPQ.h +64 -82
  64. data/vendor/faiss/faiss/IndexPQFastScan.cpp +143 -170
  65. data/vendor/faiss/faiss/IndexPQFastScan.h +46 -32
  66. data/vendor/faiss/faiss/IndexPreTransform.cpp +120 -150
  67. data/vendor/faiss/faiss/IndexPreTransform.h +33 -36
  68. data/vendor/faiss/faiss/IndexRefine.cpp +139 -127
  69. data/vendor/faiss/faiss/IndexRefine.h +32 -23
  70. data/vendor/faiss/faiss/IndexReplicas.cpp +147 -153
  71. data/vendor/faiss/faiss/IndexReplicas.h +62 -56
  72. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +111 -172
  73. data/vendor/faiss/faiss/IndexScalarQuantizer.h +41 -59
  74. data/vendor/faiss/faiss/IndexShards.cpp +256 -240
  75. data/vendor/faiss/faiss/IndexShards.h +85 -73
  76. data/vendor/faiss/faiss/MatrixStats.cpp +112 -97
  77. data/vendor/faiss/faiss/MatrixStats.h +7 -10
  78. data/vendor/faiss/faiss/MetaIndexes.cpp +135 -157
  79. data/vendor/faiss/faiss/MetaIndexes.h +40 -34
  80. data/vendor/faiss/faiss/MetricType.h +7 -7
  81. data/vendor/faiss/faiss/VectorTransform.cpp +654 -475
  82. data/vendor/faiss/faiss/VectorTransform.h +64 -89
  83. data/vendor/faiss/faiss/clone_index.cpp +78 -73
  84. data/vendor/faiss/faiss/clone_index.h +4 -9
  85. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +33 -38
  86. data/vendor/faiss/faiss/gpu/GpuAutoTune.h +11 -9
  87. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +198 -171
  88. data/vendor/faiss/faiss/gpu/GpuCloner.h +53 -35
  89. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +12 -14
  90. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +27 -25
  91. data/vendor/faiss/faiss/gpu/GpuDistance.h +116 -112
  92. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -2
  93. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  94. data/vendor/faiss/faiss/gpu/GpuIndex.h +134 -137
  95. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +76 -73
  96. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +173 -162
  97. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +67 -64
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +89 -86
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +150 -141
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +101 -103
  101. data/vendor/faiss/faiss/gpu/GpuIndicesOptions.h +17 -16
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +116 -128
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +182 -186
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +433 -422
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +131 -130
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +468 -456
  107. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.h +25 -19
  108. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +22 -20
  109. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +9 -8
  110. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +39 -44
  111. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +16 -14
  112. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +77 -71
  113. data/vendor/faiss/faiss/gpu/perf/PerfIVFPQAdd.cpp +109 -88
  114. data/vendor/faiss/faiss/gpu/perf/WriteIndex.cpp +75 -64
  115. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +230 -215
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +80 -86
  117. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +284 -277
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +416 -416
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +611 -517
  120. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +166 -164
  121. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +61 -53
  122. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +274 -238
  123. data/vendor/faiss/faiss/gpu/test/TestUtils.h +73 -57
  124. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +47 -50
  125. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +79 -72
  126. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.cpp +140 -146
  127. data/vendor/faiss/faiss/gpu/utils/StackDeviceMemory.h +69 -71
  128. data/vendor/faiss/faiss/gpu/utils/StaticUtils.h +21 -16
  129. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +25 -29
  130. data/vendor/faiss/faiss/gpu/utils/Timer.h +30 -29
  131. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +503 -0
  132. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +175 -0
  133. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +90 -120
  134. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +81 -65
  135. data/vendor/faiss/faiss/impl/FaissAssert.h +73 -58
  136. data/vendor/faiss/faiss/impl/FaissException.cpp +56 -48
  137. data/vendor/faiss/faiss/impl/FaissException.h +41 -29
  138. data/vendor/faiss/faiss/impl/HNSW.cpp +606 -617
  139. data/vendor/faiss/faiss/impl/HNSW.h +179 -200
  140. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +855 -0
  141. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +244 -0
  142. data/vendor/faiss/faiss/impl/NNDescent.cpp +487 -0
  143. data/vendor/faiss/faiss/impl/NNDescent.h +154 -0
  144. data/vendor/faiss/faiss/impl/NSG.cpp +679 -0
  145. data/vendor/faiss/faiss/impl/NSG.h +199 -0
  146. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +484 -454
  147. data/vendor/faiss/faiss/impl/PolysemousTraining.h +52 -55
  148. data/vendor/faiss/faiss/impl/ProductQuantizer-inl.h +26 -47
  149. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +469 -459
  150. data/vendor/faiss/faiss/impl/ProductQuantizer.h +76 -87
  151. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +758 -0
  152. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +188 -0
  153. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -132
  154. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +647 -707
  155. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +48 -46
  156. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +129 -131
  157. data/vendor/faiss/faiss/impl/ThreadedIndex.h +61 -55
  158. data/vendor/faiss/faiss/impl/index_read.cpp +631 -480
  159. data/vendor/faiss/faiss/impl/index_write.cpp +547 -407
  160. data/vendor/faiss/faiss/impl/io.cpp +76 -95
  161. data/vendor/faiss/faiss/impl/io.h +31 -41
  162. data/vendor/faiss/faiss/impl/io_macros.h +60 -29
  163. data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
  164. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  165. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +137 -186
  166. data/vendor/faiss/faiss/impl/lattice_Zn.h +40 -51
  167. data/vendor/faiss/faiss/impl/platform_macros.h +29 -8
  168. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +77 -124
  169. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +39 -48
  170. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +41 -52
  171. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +80 -117
  172. data/vendor/faiss/faiss/impl/simd_result_handlers.h +109 -137
  173. data/vendor/faiss/faiss/index_factory.cpp +619 -397
  174. data/vendor/faiss/faiss/index_factory.h +8 -6
  175. data/vendor/faiss/faiss/index_io.h +23 -26
  176. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +67 -75
  177. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +22 -24
  178. data/vendor/faiss/faiss/invlists/DirectMap.cpp +96 -112
  179. data/vendor/faiss/faiss/invlists/DirectMap.h +29 -33
  180. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +307 -364
  181. data/vendor/faiss/faiss/invlists/InvertedLists.h +151 -151
  182. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.cpp +29 -34
  183. data/vendor/faiss/faiss/invlists/InvertedListsIOHook.h +17 -18
  184. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +257 -293
  185. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +50 -45
  186. data/vendor/faiss/faiss/python/python_callbacks.cpp +23 -26
  187. data/vendor/faiss/faiss/python/python_callbacks.h +9 -16
  188. data/vendor/faiss/faiss/utils/AlignedTable.h +79 -44
  189. data/vendor/faiss/faiss/utils/Heap.cpp +40 -48
  190. data/vendor/faiss/faiss/utils/Heap.h +186 -209
  191. data/vendor/faiss/faiss/utils/WorkerThread.cpp +67 -76
  192. data/vendor/faiss/faiss/utils/WorkerThread.h +32 -33
  193. data/vendor/faiss/faiss/utils/distances.cpp +305 -312
  194. data/vendor/faiss/faiss/utils/distances.h +170 -122
  195. data/vendor/faiss/faiss/utils/distances_simd.cpp +498 -508
  196. data/vendor/faiss/faiss/utils/extra_distances-inl.h +117 -0
  197. data/vendor/faiss/faiss/utils/extra_distances.cpp +113 -232
  198. data/vendor/faiss/faiss/utils/extra_distances.h +30 -29
  199. data/vendor/faiss/faiss/utils/hamming-inl.h +260 -209
  200. data/vendor/faiss/faiss/utils/hamming.cpp +375 -469
  201. data/vendor/faiss/faiss/utils/hamming.h +62 -85
  202. data/vendor/faiss/faiss/utils/ordered_key_value.h +16 -18
  203. data/vendor/faiss/faiss/utils/partitioning.cpp +393 -318
  204. data/vendor/faiss/faiss/utils/partitioning.h +26 -21
  205. data/vendor/faiss/faiss/utils/quantize_lut.cpp +78 -66
  206. data/vendor/faiss/faiss/utils/quantize_lut.h +22 -20
  207. data/vendor/faiss/faiss/utils/random.cpp +39 -63
  208. data/vendor/faiss/faiss/utils/random.h +13 -16
  209. data/vendor/faiss/faiss/utils/simdlib.h +4 -2
  210. data/vendor/faiss/faiss/utils/simdlib_avx2.h +88 -85
  211. data/vendor/faiss/faiss/utils/simdlib_emulated.h +226 -165
  212. data/vendor/faiss/faiss/utils/simdlib_neon.h +832 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +304 -287
  214. data/vendor/faiss/faiss/utils/utils.h +54 -49
  215. metadata +29 -4
@@ -9,18 +9,19 @@
9
9
 
10
10
  #include <faiss/impl/ScalarQuantizer.h>
11
11
 
12
- #include <cstdio>
13
12
  #include <algorithm>
13
+ #include <cstdio>
14
14
 
15
+ #include <faiss/impl/platform_macros.h>
15
16
  #include <omp.h>
16
17
 
17
18
  #ifdef __SSE__
18
19
  #include <immintrin.h>
19
20
  #endif
20
21
 
21
- #include <faiss/utils/utils.h>
22
- #include <faiss/impl/FaissAssert.h>
23
22
  #include <faiss/impl/AuxIndexStructures.h>
23
+ #include <faiss/impl/FaissAssert.h>
24
+ #include <faiss/utils/utils.h>
24
25
 
25
26
  namespace faiss {
26
27
 
@@ -43,11 +44,11 @@ namespace faiss {
43
44
  #ifdef __F16C__
44
45
  #define USE_F16C
45
46
  #else
46
- #warning "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
47
+ #warning \
48
+ "Cannot enable AVX optimizations in scalar quantizer if -mf16c is not set as well"
47
49
  #endif
48
50
  #endif
49
51
 
50
-
51
52
  namespace {
52
53
 
53
54
  typedef Index::idx_t idx_t;
@@ -55,7 +56,6 @@ typedef ScalarQuantizer::QuantizerType QuantizerType;
55
56
  typedef ScalarQuantizer::RangeStat RangeStat;
56
57
  using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
57
58
 
58
-
59
59
  /*******************************************************************
60
60
  * Codec: converts between values in [0, 1] and an index in a code
61
61
  * array. The "i" parameter is the vector component index (not byte
@@ -63,108 +63,103 @@ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
63
63
  */
64
64
 
65
65
  struct Codec8bit {
66
-
67
- static void encode_component (float x, uint8_t *code, int i) {
66
+ static void encode_component(float x, uint8_t* code, int i) {
68
67
  code[i] = (int)(255 * x);
69
68
  }
70
69
 
71
- static float decode_component (const uint8_t *code, int i) {
70
+ static float decode_component(const uint8_t* code, int i) {
72
71
  return (code[i] + 0.5f) / 255.0f;
73
72
  }
74
73
 
75
74
  #ifdef __AVX2__
76
- static __m256 decode_8_components (const uint8_t *code, int i) {
75
+ static __m256 decode_8_components(const uint8_t* code, int i) {
77
76
  uint64_t c8 = *(uint64_t*)(code + i);
78
- __m128i c4lo = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8));
79
- __m128i c4hi = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8 >> 32));
77
+ __m128i c4lo = _mm_cvtepu8_epi32(_mm_set1_epi32(c8));
78
+ __m128i c4hi = _mm_cvtepu8_epi32(_mm_set1_epi32(c8 >> 32));
80
79
  // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
81
- __m256i i8 = _mm256_castsi128_si256 (c4lo);
82
- i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
83
- __m256 f8 = _mm256_cvtepi32_ps (i8);
84
- __m256 half = _mm256_set1_ps (0.5f);
85
- f8 += half;
86
- __m256 one_255 = _mm256_set1_ps (1.f / 255.f);
87
- return f8 * one_255;
80
+ __m256i i8 = _mm256_castsi128_si256(c4lo);
81
+ i8 = _mm256_insertf128_si256(i8, c4hi, 1);
82
+ __m256 f8 = _mm256_cvtepi32_ps(i8);
83
+ __m256 half = _mm256_set1_ps(0.5f);
84
+ f8 = _mm256_add_ps(f8, half);
85
+ __m256 one_255 = _mm256_set1_ps(1.f / 255.f);
86
+ return _mm256_mul_ps(f8, one_255);
88
87
  }
89
88
  #endif
90
89
  };
91
90
 
92
-
93
91
  struct Codec4bit {
94
-
95
- static void encode_component (float x, uint8_t *code, int i) {
96
- code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
92
+ static void encode_component(float x, uint8_t* code, int i) {
93
+ code[i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
97
94
  }
98
95
 
99
- static float decode_component (const uint8_t *code, int i) {
96
+ static float decode_component(const uint8_t* code, int i) {
100
97
  return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
101
98
  }
102
99
 
103
-
104
100
  #ifdef __AVX2__
105
- static __m256 decode_8_components (const uint8_t *code, int i) {
101
+ static __m256 decode_8_components(const uint8_t* code, int i) {
106
102
  uint32_t c4 = *(uint32_t*)(code + (i >> 1));
107
103
  uint32_t mask = 0x0f0f0f0f;
108
104
  uint32_t c4ev = c4 & mask;
109
105
  uint32_t c4od = (c4 >> 4) & mask;
110
106
 
111
107
  // the 8 lower bytes of c8 contain the values
112
- __m128i c8 = _mm_unpacklo_epi8 (_mm_set1_epi32(c4ev),
113
- _mm_set1_epi32(c4od));
114
- __m128i c4lo = _mm_cvtepu8_epi32 (c8);
115
- __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128(c8, 4));
116
- __m256i i8 = _mm256_castsi128_si256 (c4lo);
117
- i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
118
- __m256 f8 = _mm256_cvtepi32_ps (i8);
119
- __m256 half = _mm256_set1_ps (0.5f);
120
- f8 += half;
121
- __m256 one_255 = _mm256_set1_ps (1.f / 15.f);
122
- return f8 * one_255;
108
+ __m128i c8 =
109
+ _mm_unpacklo_epi8(_mm_set1_epi32(c4ev), _mm_set1_epi32(c4od));
110
+ __m128i c4lo = _mm_cvtepu8_epi32(c8);
111
+ __m128i c4hi = _mm_cvtepu8_epi32(_mm_srli_si128(c8, 4));
112
+ __m256i i8 = _mm256_castsi128_si256(c4lo);
113
+ i8 = _mm256_insertf128_si256(i8, c4hi, 1);
114
+ __m256 f8 = _mm256_cvtepi32_ps(i8);
115
+ __m256 half = _mm256_set1_ps(0.5f);
116
+ f8 = _mm256_add_ps(f8, half);
117
+ __m256 one_255 = _mm256_set1_ps(1.f / 15.f);
118
+ return _mm256_mul_ps(f8, one_255);
123
119
  }
124
120
  #endif
125
121
  };
126
122
 
127
123
  struct Codec6bit {
128
-
129
- static void encode_component (float x, uint8_t *code, int i) {
124
+ static void encode_component(float x, uint8_t* code, int i) {
130
125
  int bits = (int)(x * 63.0);
131
126
  code += (i >> 2) * 3;
132
- switch(i & 3) {
133
- case 0:
134
- code[0] |= bits;
135
- break;
136
- case 1:
137
- code[0] |= bits << 6;
138
- code[1] |= bits >> 2;
139
- break;
140
- case 2:
141
- code[1] |= bits << 4;
142
- code[2] |= bits >> 4;
143
- break;
144
- case 3:
145
- code[2] |= bits << 2;
146
- break;
127
+ switch (i & 3) {
128
+ case 0:
129
+ code[0] |= bits;
130
+ break;
131
+ case 1:
132
+ code[0] |= bits << 6;
133
+ code[1] |= bits >> 2;
134
+ break;
135
+ case 2:
136
+ code[1] |= bits << 4;
137
+ code[2] |= bits >> 4;
138
+ break;
139
+ case 3:
140
+ code[2] |= bits << 2;
141
+ break;
147
142
  }
148
143
  }
149
144
 
150
- static float decode_component (const uint8_t *code, int i) {
145
+ static float decode_component(const uint8_t* code, int i) {
151
146
  uint8_t bits;
152
147
  code += (i >> 2) * 3;
153
- switch(i & 3) {
154
- case 0:
155
- bits = code[0] & 0x3f;
156
- break;
157
- case 1:
158
- bits = code[0] >> 6;
159
- bits |= (code[1] & 0xf) << 2;
160
- break;
161
- case 2:
162
- bits = code[1] >> 4;
163
- bits |= (code[2] & 3) << 4;
164
- break;
165
- case 3:
166
- bits = code[2] >> 2;
167
- break;
148
+ switch (i & 3) {
149
+ case 0:
150
+ bits = code[0] & 0x3f;
151
+ break;
152
+ case 1:
153
+ bits = code[0] >> 6;
154
+ bits |= (code[1] & 0xf) << 2;
155
+ break;
156
+ case 2:
157
+ bits = code[1] >> 4;
158
+ bits |= (code[2] & 3) << 4;
159
+ break;
160
+ case 3:
161
+ bits = code[2] >> 2;
162
+ break;
168
163
  }
169
164
  return (bits + 0.5f) / 63.0f;
170
165
  }
@@ -173,12 +168,14 @@ struct Codec6bit {
173
168
 
174
169
  /* Load 6 bytes that represent 8 6-bit values, return them as a
175
170
  * 8*32 bit vector register */
176
- static __m256i load6 (const uint16_t *code16) {
177
- const __m128i perm = _mm_set_epi8(-1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
171
+ static __m256i load6(const uint16_t* code16) {
172
+ const __m128i perm = _mm_set_epi8(
173
+ -1, 5, 5, 4, 4, 3, -1, 3, -1, 2, 2, 1, 1, 0, -1, 0);
178
174
  const __m256i shifts = _mm256_set_epi32(2, 4, 6, 0, 2, 4, 6, 0);
179
175
 
180
176
  // load 6 bytes
181
- __m128i c1 = _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
177
+ __m128i c1 =
178
+ _mm_set_epi16(0, 0, 0, 0, 0, code16[2], code16[1], code16[0]);
182
179
 
183
180
  // put in 8 * 32 bits
184
181
  __m128i c2 = _mm_shuffle_epi8(c1, perm);
@@ -190,37 +187,33 @@ struct Codec6bit {
190
187
  return c5;
191
188
  }
192
189
 
193
- static __m256 decode_8_components (const uint8_t *code, int i) {
194
- __m256i i8 = load6 ((const uint16_t *)(code + (i >> 2) * 3));
195
- __m256 f8 = _mm256_cvtepi32_ps (i8);
190
+ static __m256 decode_8_components(const uint8_t* code, int i) {
191
+ __m256i i8 = load6((const uint16_t*)(code + (i >> 2) * 3));
192
+ __m256 f8 = _mm256_cvtepi32_ps(i8);
196
193
  // this could also be done with bit manipulations but it is
197
194
  // not obviously faster
198
- __m256 half = _mm256_set1_ps (0.5f);
199
- f8 += half;
200
- __m256 one_63 = _mm256_set1_ps (1.f / 63.f);
201
- return f8 * one_63;
195
+ __m256 half = _mm256_set1_ps(0.5f);
196
+ f8 = _mm256_add_ps(f8, half);
197
+ __m256 one_63 = _mm256_set1_ps(1.f / 63.f);
198
+ return _mm256_mul_ps(f8, one_63);
202
199
  }
203
200
 
204
201
  #endif
205
202
  };
206
203
 
207
-
208
-
209
204
  #ifdef USE_F16C
210
205
 
211
-
212
- uint16_t encode_fp16 (float x) {
213
- __m128 xf = _mm_set1_ps (x);
214
- __m128i xi = _mm_cvtps_ph (
215
- xf, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
216
- return _mm_cvtsi128_si32 (xi) & 0xffff;
206
+ uint16_t encode_fp16(float x) {
207
+ __m128 xf = _mm_set1_ps(x);
208
+ __m128i xi =
209
+ _mm_cvtps_ph(xf, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
210
+ return _mm_cvtsi128_si32(xi) & 0xffff;
217
211
  }
218
212
 
219
-
220
- float decode_fp16 (uint16_t x) {
221
- __m128i xi = _mm_set1_epi16 (x);
222
- __m128 xf = _mm_cvtph_ps (xi);
223
- return _mm_cvtss_f32 (xf);
213
+ float decode_fp16(uint16_t x) {
214
+ __m128i xi = _mm_set1_epi16(x);
215
+ __m128 xf = _mm_cvtph_ps(xi);
216
+ return _mm_cvtss_f32(xf);
224
217
  }
225
218
 
226
219
  #else
@@ -228,19 +221,17 @@ float decode_fp16 (uint16_t x) {
228
221
  // non-intrinsic FP16 <-> FP32 code adapted from
229
222
  // https://github.com/ispc/ispc/blob/master/stdlib.ispc
230
223
 
231
- float floatbits (uint32_t x) {
232
- void *xptr = &x;
224
+ float floatbits(uint32_t x) {
225
+ void* xptr = &x;
233
226
  return *(float*)xptr;
234
227
  }
235
228
 
236
- uint32_t intbits (float f) {
237
- void *fptr = &f;
229
+ uint32_t intbits(float f) {
230
+ void* fptr = &f;
238
231
  return *(uint32_t*)fptr;
239
232
  }
240
233
 
241
-
242
- uint16_t encode_fp16 (float f) {
243
-
234
+ uint16_t encode_fp16(float f) {
244
235
  // via Fabian "ryg" Giesen.
245
236
  // https://gist.github.com/2156668
246
237
  uint32_t sign_mask = 0x80000000u;
@@ -297,20 +288,19 @@ uint16_t encode_fp16 (float f) {
297
288
  return (o | (sign >> 16));
298
289
  }
299
290
 
300
- float decode_fp16 (uint16_t h) {
301
-
291
+ float decode_fp16(uint16_t h) {
302
292
  // https://gist.github.com/2144712
303
293
  // Fabian "ryg" Giesen.
304
294
 
305
295
  const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
306
296
 
307
- int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
308
- int32_t exp = shifted_exp & o; // just the exponent
309
- o += (int32_t)(127 - 15) << 23; // exponent adjust
297
+ int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
298
+ int32_t exp = shifted_exp & o; // just the exponent
299
+ o += (int32_t)(127 - 15) << 23; // exponent adjust
310
300
 
311
301
  int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
312
- int32_t zerodenorm_val = intbits(
313
- floatbits(o + (1u<<23)) - floatbits(113u << 23));
302
+ int32_t zerodenorm_val =
303
+ intbits(floatbits(o + (1u << 23)) - floatbits(113u << 23));
314
304
  int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
315
305
 
316
306
  int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
@@ -319,30 +309,21 @@ float decode_fp16 (uint16_t h) {
319
309
 
320
310
  #endif
321
311
 
322
-
323
-
324
312
  /*******************************************************************
325
313
  * Quantizer: normalizes scalar vector components, then passes them
326
314
  * through a codec
327
315
  *******************************************************************/
328
316
 
329
-
330
-
331
-
332
-
333
- template<class Codec, bool uniform, int SIMD>
317
+ template <class Codec, bool uniform, int SIMD>
334
318
  struct QuantizerTemplate {};
335
319
 
336
-
337
- template<class Codec>
338
- struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
320
+ template <class Codec>
321
+ struct QuantizerTemplate<Codec, true, 1> : ScalarQuantizer::Quantizer {
339
322
  const size_t d;
340
323
  const float vmin, vdiff;
341
324
 
342
- QuantizerTemplate(size_t d, const std::vector<float> &trained):
343
- d(d), vmin(trained[0]), vdiff(trained[1])
344
- {
345
- }
325
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
326
+ : d(d), vmin(trained[0]), vdiff(trained[1]) {}
346
327
 
347
328
  void encode_vector(const float* x, uint8_t* code) const final {
348
329
  for (size_t i = 0; i < d; i++) {
@@ -367,43 +348,36 @@ struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
367
348
  }
368
349
  }
369
350
 
370
- float reconstruct_component (const uint8_t * code, int i) const
371
- {
372
- float xi = Codec::decode_component (code, i);
351
+ float reconstruct_component(const uint8_t* code, int i) const {
352
+ float xi = Codec::decode_component(code, i);
373
353
  return vmin + xi * vdiff;
374
354
  }
375
-
376
355
  };
377
356
 
378
-
379
-
380
357
  #ifdef __AVX2__
381
358
 
382
- template<class Codec>
383
- struct QuantizerTemplate<Codec, true, 8>: QuantizerTemplate<Codec, true, 1> {
384
-
385
- QuantizerTemplate (size_t d, const std::vector<float> &trained):
386
- QuantizerTemplate<Codec, true, 1> (d, trained) {}
359
+ template <class Codec>
360
+ struct QuantizerTemplate<Codec, true, 8> : QuantizerTemplate<Codec, true, 1> {
361
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
362
+ : QuantizerTemplate<Codec, true, 1>(d, trained) {}
387
363
 
388
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
389
- {
390
- __m256 xi = Codec::decode_8_components (code, i);
391
- return _mm256_set1_ps(this->vmin) + xi * _mm256_set1_ps (this->vdiff);
364
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
365
+ __m256 xi = Codec::decode_8_components(code, i);
366
+ return _mm256_add_ps(
367
+ _mm256_set1_ps(this->vmin),
368
+ _mm256_mul_ps(xi, _mm256_set1_ps(this->vdiff)));
392
369
  }
393
-
394
370
  };
395
371
 
396
372
  #endif
397
373
 
398
-
399
-
400
- template<class Codec>
401
- struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
374
+ template <class Codec>
375
+ struct QuantizerTemplate<Codec, false, 1> : ScalarQuantizer::Quantizer {
402
376
  const size_t d;
403
377
  const float *vmin, *vdiff;
404
378
 
405
- QuantizerTemplate (size_t d, const std::vector<float> &trained):
406
- d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
379
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
380
+ : d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
407
381
 
408
382
  void encode_vector(const float* x, uint8_t* code) const final {
409
383
  for (size_t i = 0; i < d; i++) {
@@ -428,30 +402,25 @@ struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
428
402
  }
429
403
  }
430
404
 
431
- float reconstruct_component (const uint8_t * code, int i) const
432
- {
433
- float xi = Codec::decode_component (code, i);
405
+ float reconstruct_component(const uint8_t* code, int i) const {
406
+ float xi = Codec::decode_component(code, i);
434
407
  return vmin[i] + xi * vdiff[i];
435
408
  }
436
-
437
409
  };
438
410
 
439
-
440
411
  #ifdef __AVX2__
441
412
 
442
- template<class Codec>
443
- struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
444
-
445
- QuantizerTemplate (size_t d, const std::vector<float> &trained):
446
- QuantizerTemplate<Codec, false, 1> (d, trained) {}
413
+ template <class Codec>
414
+ struct QuantizerTemplate<Codec, false, 8> : QuantizerTemplate<Codec, false, 1> {
415
+ QuantizerTemplate(size_t d, const std::vector<float>& trained)
416
+ : QuantizerTemplate<Codec, false, 1>(d, trained) {}
447
417
 
448
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
449
- {
450
- __m256 xi = Codec::decode_8_components (code, i);
451
- return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i);
418
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
419
+ __m256 xi = Codec::decode_8_components(code, i);
420
+ return _mm256_add_ps(
421
+ _mm256_loadu_ps(this->vmin + i),
422
+ _mm256_mul_ps(xi, _mm256_loadu_ps(this->vdiff + i)));
452
423
  }
453
-
454
-
455
424
  };
456
425
 
457
426
  #endif
@@ -460,15 +429,14 @@ struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
460
429
  * FP16 quantizer
461
430
  *******************************************************************/
462
431
 
463
- template<int SIMDWIDTH>
432
+ template <int SIMDWIDTH>
464
433
  struct QuantizerFP16 {};
465
434
 
466
- template<>
467
- struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
435
+ template <>
436
+ struct QuantizerFP16<1> : ScalarQuantizer::Quantizer {
468
437
  const size_t d;
469
438
 
470
- QuantizerFP16(size_t d, const std::vector<float> & /* unused */):
471
- d(d) {}
439
+ QuantizerFP16(size_t d, const std::vector<float>& /* unused */) : d(d) {}
472
440
 
473
441
  void encode_vector(const float* x, uint8_t* code) const final {
474
442
  for (size_t i = 0; i < d; i++) {
@@ -482,27 +450,22 @@ struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
482
450
  }
483
451
  }
484
452
 
485
- float reconstruct_component (const uint8_t * code, int i) const
486
- {
453
+ float reconstruct_component(const uint8_t* code, int i) const {
487
454
  return decode_fp16(((uint16_t*)code)[i]);
488
455
  }
489
-
490
456
  };
491
457
 
492
458
  #ifdef USE_F16C
493
459
 
494
- template<>
495
- struct QuantizerFP16<8>: QuantizerFP16<1> {
460
+ template <>
461
+ struct QuantizerFP16<8> : QuantizerFP16<1> {
462
+ QuantizerFP16(size_t d, const std::vector<float>& trained)
463
+ : QuantizerFP16<1>(d, trained) {}
496
464
 
497
- QuantizerFP16 (size_t d, const std::vector<float> &trained):
498
- QuantizerFP16<1> (d, trained) {}
499
-
500
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
501
- {
502
- __m128i codei = _mm_loadu_si128 ((const __m128i*)(code + 2 * i));
503
- return _mm256_cvtph_ps (codei);
465
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
466
+ __m128i codei = _mm_loadu_si128((const __m128i*)(code + 2 * i));
467
+ return _mm256_cvtph_ps(codei);
504
468
  }
505
-
506
469
  };
507
470
 
508
471
  #endif
@@ -511,16 +474,15 @@ struct QuantizerFP16<8>: QuantizerFP16<1> {
511
474
  * 8bit_direct quantizer
512
475
  *******************************************************************/
513
476
 
514
- template<int SIMDWIDTH>
477
+ template <int SIMDWIDTH>
515
478
  struct Quantizer8bitDirect {};
516
479
 
517
- template<>
518
- struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
480
+ template <>
481
+ struct Quantizer8bitDirect<1> : ScalarQuantizer::Quantizer {
519
482
  const size_t d;
520
483
 
521
- Quantizer8bitDirect(size_t d, const std::vector<float> & /* unused */):
522
- d(d) {}
523
-
484
+ Quantizer8bitDirect(size_t d, const std::vector<float>& /* unused */)
485
+ : d(d) {}
524
486
 
525
487
  void encode_vector(const float* x, uint8_t* code) const final {
526
488
  for (size_t i = 0; i < d; i++) {
@@ -534,82 +496,83 @@ struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
534
496
  }
535
497
  }
536
498
 
537
- float reconstruct_component (const uint8_t * code, int i) const
538
- {
499
+ float reconstruct_component(const uint8_t* code, int i) const {
539
500
  return code[i];
540
501
  }
541
-
542
502
  };
543
503
 
544
504
  #ifdef __AVX2__
545
505
 
546
- template<>
547
- struct Quantizer8bitDirect<8>: Quantizer8bitDirect<1> {
548
-
549
- Quantizer8bitDirect (size_t d, const std::vector<float> &trained):
550
- Quantizer8bitDirect<1> (d, trained) {}
506
+ template <>
507
+ struct Quantizer8bitDirect<8> : Quantizer8bitDirect<1> {
508
+ Quantizer8bitDirect(size_t d, const std::vector<float>& trained)
509
+ : Quantizer8bitDirect<1>(d, trained) {}
551
510
 
552
- __m256 reconstruct_8_components (const uint8_t * code, int i) const
553
- {
511
+ __m256 reconstruct_8_components(const uint8_t* code, int i) const {
554
512
  __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
555
- __m256i y8 = _mm256_cvtepu8_epi32 (x8); // 8 * int32
556
- return _mm256_cvtepi32_ps (y8); // 8 * float32
513
+ __m256i y8 = _mm256_cvtepu8_epi32(x8); // 8 * int32
514
+ return _mm256_cvtepi32_ps(y8); // 8 * float32
557
515
  }
558
-
559
516
  };
560
517
 
561
518
  #endif
562
519
 
563
-
564
- template<int SIMDWIDTH>
565
- ScalarQuantizer::Quantizer *select_quantizer_1 (
566
- QuantizerType qtype,
567
- size_t d, const std::vector<float> & trained)
568
- {
569
- switch(qtype) {
570
- case ScalarQuantizer::QT_8bit:
571
- return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(d, trained);
572
- case ScalarQuantizer::QT_6bit:
573
- return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(d, trained);
574
- case ScalarQuantizer::QT_4bit:
575
- return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(d, trained);
576
- case ScalarQuantizer::QT_8bit_uniform:
577
- return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(d, trained);
578
- case ScalarQuantizer::QT_4bit_uniform:
579
- return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(d, trained);
580
- case ScalarQuantizer::QT_fp16:
581
- return new QuantizerFP16<SIMDWIDTH> (d, trained);
582
- case ScalarQuantizer::QT_8bit_direct:
583
- return new Quantizer8bitDirect<SIMDWIDTH> (d, trained);
584
- }
585
- FAISS_THROW_MSG ("unknown qtype");
520
+ template <int SIMDWIDTH>
521
+ ScalarQuantizer::Quantizer* select_quantizer_1(
522
+ QuantizerType qtype,
523
+ size_t d,
524
+ const std::vector<float>& trained) {
525
+ switch (qtype) {
526
+ case ScalarQuantizer::QT_8bit:
527
+ return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(
528
+ d, trained);
529
+ case ScalarQuantizer::QT_6bit:
530
+ return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(
531
+ d, trained);
532
+ case ScalarQuantizer::QT_4bit:
533
+ return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(
534
+ d, trained);
535
+ case ScalarQuantizer::QT_8bit_uniform:
536
+ return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(
537
+ d, trained);
538
+ case ScalarQuantizer::QT_4bit_uniform:
539
+ return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(
540
+ d, trained);
541
+ case ScalarQuantizer::QT_fp16:
542
+ return new QuantizerFP16<SIMDWIDTH>(d, trained);
543
+ case ScalarQuantizer::QT_8bit_direct:
544
+ return new Quantizer8bitDirect<SIMDWIDTH>(d, trained);
545
+ }
546
+ FAISS_THROW_MSG("unknown qtype");
586
547
  }
587
548
 
588
-
589
-
590
-
591
549
  /*******************************************************************
592
550
  * Quantizer range training
593
551
  */
594
552
 
595
- static float sqr (float x) {
553
+ static float sqr(float x) {
596
554
  return x * x;
597
555
  }
598
556
 
599
-
600
- void train_Uniform(RangeStat rs, float rs_arg,
601
- idx_t n, int k, const float *x,
602
- std::vector<float> & trained)
603
- {
604
- trained.resize (2);
605
- float & vmin = trained[0];
606
- float & vmax = trained[1];
557
+ void train_Uniform(
558
+ RangeStat rs,
559
+ float rs_arg,
560
+ idx_t n,
561
+ int k,
562
+ const float* x,
563
+ std::vector<float>& trained) {
564
+ trained.resize(2);
565
+ float& vmin = trained[0];
566
+ float& vmax = trained[1];
607
567
 
608
568
  if (rs == ScalarQuantizer::RS_minmax) {
609
- vmin = HUGE_VAL; vmax = -HUGE_VAL;
569
+ vmin = HUGE_VAL;
570
+ vmax = -HUGE_VAL;
610
571
  for (size_t i = 0; i < n; i++) {
611
- if (x[i] < vmin) vmin = x[i];
612
- if (x[i] > vmax) vmax = x[i];
572
+ if (x[i] < vmin)
573
+ vmin = x[i];
574
+ if (x[i] > vmax)
575
+ vmax = x[i];
613
576
  }
614
577
  float vexp = (vmax - vmin) * rs_arg;
615
578
  vmin -= vexp;
@@ -624,16 +587,18 @@ void train_Uniform(RangeStat rs, float rs_arg,
624
587
  float var = sum2 / n - mean * mean;
625
588
  float std = var <= 0 ? 1.0 : sqrt(var);
626
589
 
627
- vmin = mean - std * rs_arg ;
628
- vmax = mean + std * rs_arg ;
590
+ vmin = mean - std * rs_arg;
591
+ vmax = mean + std * rs_arg;
629
592
  } else if (rs == ScalarQuantizer::RS_quantiles) {
630
593
  std::vector<float> x_copy(n);
631
594
  memcpy(x_copy.data(), x, n * sizeof(*x));
632
595
  // TODO just do a qucikselect
633
596
  std::sort(x_copy.begin(), x_copy.end());
634
597
  int o = int(rs_arg * n);
635
- if (o < 0) o = 0;
636
- if (o > n - o) o = n / 2;
598
+ if (o < 0)
599
+ o = 0;
600
+ if (o > n - o)
601
+ o = n / 2;
637
602
  vmin = x_copy[o];
638
603
  vmax = x_copy[n - 1 - o];
639
604
 
@@ -643,8 +608,10 @@ void train_Uniform(RangeStat rs, float rs_arg,
643
608
  {
644
609
  vmin = HUGE_VAL, vmax = -HUGE_VAL;
645
610
  for (size_t i = 0; i < n; i++) {
646
- if (x[i] < vmin) vmin = x[i];
647
- if (x[i] > vmax) vmax = x[i];
611
+ if (x[i] < vmin)
612
+ vmin = x[i];
613
+ if (x[i] > vmax)
614
+ vmax = x[i];
648
615
  sx += x[i];
649
616
  }
650
617
  b = vmin;
@@ -659,62 +626,71 @@ void train_Uniform(RangeStat rs, float rs_arg,
659
626
 
660
627
  for (idx_t i = 0; i < n; i++) {
661
628
  float xi = x[i];
662
- float ni = floor ((xi - b) / a + 0.5);
663
- if (ni < 0) ni = 0;
664
- if (ni >= k) ni = k - 1;
665
- err1 += sqr (xi - (ni * a + b));
666
- sn += ni;
629
+ float ni = floor((xi - b) / a + 0.5);
630
+ if (ni < 0)
631
+ ni = 0;
632
+ if (ni >= k)
633
+ ni = k - 1;
634
+ err1 += sqr(xi - (ni * a + b));
635
+ sn += ni;
667
636
  sn2 += ni * ni;
668
637
  sxn += ni * xi;
669
638
  }
670
639
 
671
640
  if (err1 == last_err) {
672
- iter_last_err ++;
673
- if (iter_last_err == 16) break;
641
+ iter_last_err++;
642
+ if (iter_last_err == 16)
643
+ break;
674
644
  } else {
675
645
  last_err = err1;
676
646
  iter_last_err = 0;
677
647
  }
678
648
 
679
- float det = sqr (sn) - sn2 * n;
649
+ float det = sqr(sn) - sn2 * n;
680
650
 
681
651
  b = (sn * sxn - sn2 * sx) / det;
682
652
  a = (sn * sx - n * sxn) / det;
683
653
  if (verbose) {
684
- printf ("it %d, err1=%g \r", it, err1);
654
+ printf("it %d, err1=%g \r", it, err1);
685
655
  fflush(stdout);
686
656
  }
687
657
  }
688
- if (verbose) printf("\n");
658
+ if (verbose)
659
+ printf("\n");
689
660
 
690
661
  vmin = b;
691
662
  vmax = b + a * (k - 1);
692
663
 
693
664
  } else {
694
- FAISS_THROW_MSG ("Invalid qtype");
665
+ FAISS_THROW_MSG("Invalid qtype");
695
666
  }
696
667
  vmax -= vmin;
697
668
  }
698
669
 
699
- void train_NonUniform(RangeStat rs, float rs_arg,
700
- idx_t n, int d, int k, const float *x,
701
- std::vector<float> & trained)
702
- {
703
-
704
- trained.resize (2 * d);
705
- float * vmin = trained.data();
706
- float * vmax = trained.data() + d;
670
+ void train_NonUniform(
671
+ RangeStat rs,
672
+ float rs_arg,
673
+ idx_t n,
674
+ int d,
675
+ int k,
676
+ const float* x,
677
+ std::vector<float>& trained) {
678
+ trained.resize(2 * d);
679
+ float* vmin = trained.data();
680
+ float* vmax = trained.data() + d;
707
681
  if (rs == ScalarQuantizer::RS_minmax) {
708
- memcpy (vmin, x, sizeof(*x) * d);
709
- memcpy (vmax, x, sizeof(*x) * d);
682
+ memcpy(vmin, x, sizeof(*x) * d);
683
+ memcpy(vmax, x, sizeof(*x) * d);
710
684
  for (size_t i = 1; i < n; i++) {
711
- const float *xi = x + i * d;
685
+ const float* xi = x + i * d;
712
686
  for (size_t j = 0; j < d; j++) {
713
- if (xi[j] < vmin[j]) vmin[j] = xi[j];
714
- if (xi[j] > vmax[j]) vmax[j] = xi[j];
687
+ if (xi[j] < vmin[j])
688
+ vmin[j] = xi[j];
689
+ if (xi[j] > vmax[j])
690
+ vmax[j] = xi[j];
715
691
  }
716
692
  }
717
- float *vdiff = vmax;
693
+ float* vdiff = vmax;
718
694
  for (size_t j = 0; j < d; j++) {
719
695
  float vexp = (vmax[j] - vmin[j]) * rs_arg;
720
696
  vmin[j] -= vexp;
@@ -725,7 +701,7 @@ void train_NonUniform(RangeStat rs, float rs_arg,
725
701
  // transpose
726
702
  std::vector<float> xt(n * d);
727
703
  for (size_t i = 1; i < n; i++) {
728
- const float *xi = x + i * d;
704
+ const float* xi = x + i * d;
729
705
  for (size_t j = 0; j < d; j++) {
730
706
  xt[j * n + i] = xi[j];
731
707
  }
@@ -733,108 +709,98 @@ void train_NonUniform(RangeStat rs, float rs_arg,
733
709
  std::vector<float> trained_d(2);
734
710
  #pragma omp parallel for
735
711
  for (int j = 0; j < d; j++) {
736
- train_Uniform(rs, rs_arg,
737
- n, k, xt.data() + j * n,
738
- trained_d);
712
+ train_Uniform(rs, rs_arg, n, k, xt.data() + j * n, trained_d);
739
713
  vmin[j] = trained_d[0];
740
714
  vmax[j] = trained_d[1];
741
715
  }
742
716
  }
743
717
  }
744
718
 
745
-
746
-
747
719
  /*******************************************************************
748
720
  * Similarity: gets vector components and computes a similarity wrt. a
749
721
  * query vector stored in the object. The data fields just encapsulate
750
722
  * an accumulator.
751
723
  */
752
724
 
753
- template<int SIMDWIDTH>
725
+ template <int SIMDWIDTH>
754
726
  struct SimilarityL2 {};
755
727
 
756
-
757
- template<>
728
+ template <>
758
729
  struct SimilarityL2<1> {
759
730
  static constexpr int simdwidth = 1;
760
731
  static constexpr MetricType metric_type = METRIC_L2;
761
732
 
762
733
  const float *y, *yi;
763
734
 
764
- explicit SimilarityL2 (const float * y): y(y) {}
735
+ explicit SimilarityL2(const float* y) : y(y) {}
765
736
 
766
737
  /******* scalar accumulator *******/
767
738
 
768
739
  float accu;
769
740
 
770
- void begin () {
741
+ void begin() {
771
742
  accu = 0;
772
743
  yi = y;
773
744
  }
774
745
 
775
- void add_component (float x) {
746
+ void add_component(float x) {
776
747
  float tmp = *yi++ - x;
777
748
  accu += tmp * tmp;
778
749
  }
779
750
 
780
- void add_component_2 (float x1, float x2) {
751
+ void add_component_2(float x1, float x2) {
781
752
  float tmp = x1 - x2;
782
753
  accu += tmp * tmp;
783
754
  }
784
755
 
785
- float result () {
756
+ float result() {
786
757
  return accu;
787
758
  }
788
759
  };
789
760
 
790
-
791
761
  #ifdef __AVX2__
792
- template<>
762
+ template <>
793
763
  struct SimilarityL2<8> {
794
764
  static constexpr int simdwidth = 8;
795
765
  static constexpr MetricType metric_type = METRIC_L2;
796
766
 
797
767
  const float *y, *yi;
798
768
 
799
- explicit SimilarityL2 (const float * y): y(y) {}
769
+ explicit SimilarityL2(const float* y) : y(y) {}
800
770
  __m256 accu8;
801
771
 
802
- void begin_8 () {
772
+ void begin_8() {
803
773
  accu8 = _mm256_setzero_ps();
804
774
  yi = y;
805
775
  }
806
776
 
807
- void add_8_components (__m256 x) {
808
- __m256 yiv = _mm256_loadu_ps (yi);
777
+ void add_8_components(__m256 x) {
778
+ __m256 yiv = _mm256_loadu_ps(yi);
809
779
  yi += 8;
810
- __m256 tmp = yiv - x;
811
- accu8 += tmp * tmp;
780
+ __m256 tmp = _mm256_sub_ps(yiv, x);
781
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
812
782
  }
813
783
 
814
- void add_8_components_2 (__m256 x, __m256 y) {
815
- __m256 tmp = y - x;
816
- accu8 += tmp * tmp;
784
+ void add_8_components_2(__m256 x, __m256 y) {
785
+ __m256 tmp = _mm256_sub_ps(y, x);
786
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(tmp, tmp));
817
787
  }
818
788
 
819
- float result_8 () {
789
+ float result_8() {
820
790
  __m256 sum = _mm256_hadd_ps(accu8, accu8);
821
791
  __m256 sum2 = _mm256_hadd_ps(sum, sum);
822
792
  // now add the 0th and 4th component
823
- return
824
- _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
825
- _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
793
+ return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
794
+ _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
826
795
  }
827
-
828
796
  };
829
797
 
830
798
  #endif
831
799
 
832
-
833
- template<int SIMDWIDTH>
800
+ template <int SIMDWIDTH>
834
801
  struct SimilarityIP {};
835
802
 
836
-
837
- template<>
803
+ template <>
838
804
  struct SimilarityIP<1> {
839
805
  static constexpr int simdwidth = 1;
840
806
  static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
@@ -842,30 +808,29 @@ struct SimilarityIP<1> {
842
808
 
843
809
  float accu;
844
810
 
845
- explicit SimilarityIP (const float * y):
846
- y (y) {}
811
+ explicit SimilarityIP(const float* y) : y(y) {}
847
812
 
848
- void begin () {
813
+ void begin() {
849
814
  accu = 0;
850
815
  yi = y;
851
816
  }
852
817
 
853
- void add_component (float x) {
854
- accu += *yi++ * x;
818
+ void add_component(float x) {
819
+ accu += *yi++ * x;
855
820
  }
856
821
 
857
- void add_component_2 (float x1, float x2) {
858
- accu += x1 * x2;
822
+ void add_component_2(float x1, float x2) {
823
+ accu += x1 * x2;
859
824
  }
860
825
 
861
- float result () {
826
+ float result() {
862
827
  return accu;
863
828
  }
864
829
  };
865
830
 
866
831
  #ifdef __AVX2__
867
832
 
868
- template<>
833
+ template <>
869
834
  struct SimilarityIP<8> {
870
835
  static constexpr int simdwidth = 8;
871
836
  static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
@@ -874,59 +839,53 @@ struct SimilarityIP<8> {
874
839
 
875
840
  float accu;
876
841
 
877
- explicit SimilarityIP (const float * y):
878
- y (y) {}
842
+ explicit SimilarityIP(const float* y) : y(y) {}
879
843
 
880
844
  __m256 accu8;
881
845
 
882
- void begin_8 () {
846
+ void begin_8() {
883
847
  accu8 = _mm256_setzero_ps();
884
848
  yi = y;
885
849
  }
886
850
 
887
- void add_8_components (__m256 x) {
888
- __m256 yiv = _mm256_loadu_ps (yi);
851
+ void add_8_components(__m256 x) {
852
+ __m256 yiv = _mm256_loadu_ps(yi);
889
853
  yi += 8;
890
- accu8 += yiv * x;
854
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(yiv, x));
891
855
  }
892
856
 
893
- void add_8_components_2 (__m256 x1, __m256 x2) {
894
- accu8 += x1 * x2;
857
+ void add_8_components_2(__m256 x1, __m256 x2) {
858
+ accu8 = _mm256_add_ps(accu8, _mm256_mul_ps(x1, x2));
895
859
  }
896
860
 
897
- float result_8 () {
861
+ float result_8() {
898
862
  __m256 sum = _mm256_hadd_ps(accu8, accu8);
899
863
  __m256 sum2 = _mm256_hadd_ps(sum, sum);
900
864
  // now add the 0th and 4th component
901
- return
902
- _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
903
- _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
865
+ return _mm_cvtss_f32(_mm256_castps256_ps128(sum2)) +
866
+ _mm_cvtss_f32(_mm256_extractf128_ps(sum2, 1));
904
867
  }
905
868
  };
906
869
  #endif
907
870
 
908
-
909
871
  /*******************************************************************
910
872
  * DistanceComputer: combines a similarity and a quantizer to do
911
873
  * code-to-vector or code-to-code comparisons
912
874
  *******************************************************************/
913
875
 
914
- template<class Quantizer, class Similarity, int SIMDWIDTH>
876
+ template <class Quantizer, class Similarity, int SIMDWIDTH>
915
877
  struct DCTemplate : SQDistanceComputer {};
916
878
 
917
- template<class Quantizer, class Similarity>
918
- struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
919
- {
879
+ template <class Quantizer, class Similarity>
880
+ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer {
920
881
  using Sim = Similarity;
921
882
 
922
883
  Quantizer quant;
923
884
 
924
- DCTemplate(size_t d, const std::vector<float> &trained):
925
- quant(d, trained)
926
- {}
885
+ DCTemplate(size_t d, const std::vector<float>& trained)
886
+ : quant(d, trained) {}
927
887
 
928
888
  float compute_distance(const float* x, const uint8_t* code) const {
929
-
930
889
  Similarity sim(x);
931
890
  sim.begin();
932
891
  for (size_t i = 0; i < quant.d; i++) {
@@ -937,7 +896,7 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
937
896
  }
938
897
 
939
898
  float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
940
- const {
899
+ const {
941
900
  Similarity sim(nullptr);
942
901
  sim.begin();
943
902
  for (size_t i = 0; i < quant.d; i++) {
@@ -948,41 +907,37 @@ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
948
907
  return sim.result();
949
908
  }
950
909
 
951
- void set_query (const float *x) final {
910
+ void set_query(const float* x) final {
952
911
  q = x;
953
912
  }
954
913
 
955
914
  /// compute distance of vector i to current query
956
- float operator () (idx_t i) final {
957
- return compute_distance (q, codes + i * code_size);
915
+ float operator()(idx_t i) final {
916
+ return query_to_code(codes + i * code_size);
958
917
  }
959
918
 
960
- float symmetric_dis (idx_t i, idx_t j) override {
961
- return compute_code_distance (codes + i * code_size,
962
- codes + j * code_size);
919
+ float symmetric_dis(idx_t i, idx_t j) override {
920
+ return compute_code_distance(
921
+ codes + i * code_size, codes + j * code_size);
963
922
  }
964
923
 
965
- float query_to_code (const uint8_t * code) const {
966
- return compute_distance (q, code);
924
+ float query_to_code(const uint8_t* code) const final {
925
+ return compute_distance(q, code);
967
926
  }
968
-
969
927
  };
970
928
 
971
929
  #ifdef USE_F16C
972
930
 
973
- template<class Quantizer, class Similarity>
974
- struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
975
- {
931
+ template <class Quantizer, class Similarity>
932
+ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer {
976
933
  using Sim = Similarity;
977
934
 
978
935
  Quantizer quant;
979
936
 
980
- DCTemplate(size_t d, const std::vector<float> &trained):
981
- quant(d, trained)
982
- {}
937
+ DCTemplate(size_t d, const std::vector<float>& trained)
938
+ : quant(d, trained) {}
983
939
 
984
940
  float compute_distance(const float* x, const uint8_t* code) const {
985
-
986
941
  Similarity sim(x);
987
942
  sim.begin_8();
988
943
  for (size_t i = 0; i < quant.d; i += 8) {
@@ -993,7 +948,7 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
993
948
  }
994
949
 
995
950
  float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
996
- const {
951
+ const {
997
952
  Similarity sim(nullptr);
998
953
  sim.begin_8();
999
954
  for (size_t i = 0; i < quant.d; i += 8) {
@@ -1004,49 +959,45 @@ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
1004
959
  return sim.result_8();
1005
960
  }
1006
961
 
1007
- void set_query (const float *x) final {
962
+ void set_query(const float* x) final {
1008
963
  q = x;
1009
964
  }
1010
965
 
1011
966
  /// compute distance of vector i to current query
1012
- float operator () (idx_t i) final {
1013
- return compute_distance (q, codes + i * code_size);
967
+ float operator()(idx_t i) final {
968
+ return query_to_code(codes + i * code_size);
1014
969
  }
1015
970
 
1016
- float symmetric_dis (idx_t i, idx_t j) override {
1017
- return compute_code_distance (codes + i * code_size,
1018
- codes + j * code_size);
971
+ float symmetric_dis(idx_t i, idx_t j) override {
972
+ return compute_code_distance(
973
+ codes + i * code_size, codes + j * code_size);
1019
974
  }
1020
975
 
1021
- float query_to_code (const uint8_t * code) const {
1022
- return compute_distance (q, code);
976
+ float query_to_code(const uint8_t* code) const final {
977
+ return compute_distance(q, code);
1023
978
  }
1024
-
1025
979
  };
1026
980
 
1027
981
  #endif
1028
982
 
1029
-
1030
-
1031
983
  /*******************************************************************
1032
984
  * DistanceComputerByte: computes distances in the integer domain
1033
985
  *******************************************************************/
1034
986
 
1035
- template<class Similarity, int SIMDWIDTH>
987
+ template <class Similarity, int SIMDWIDTH>
1036
988
  struct DistanceComputerByte : SQDistanceComputer {};
1037
989
 
1038
- template<class Similarity>
990
+ template <class Similarity>
1039
991
  struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1040
992
  using Sim = Similarity;
1041
993
 
1042
994
  int d;
1043
995
  std::vector<uint8_t> tmp;
1044
996
 
1045
- DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1046
- }
997
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1047
998
 
1048
999
  int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1049
- const {
1000
+ const {
1050
1001
  int accu = 0;
1051
1002
  for (int i = 0; i < d; i++) {
1052
1003
  if (Sim::metric_type == METRIC_INNER_PRODUCT) {
@@ -1059,7 +1010,7 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1059
1010
  return accu;
1060
1011
  }
1061
1012
 
1062
- void set_query (const float *x) final {
1013
+ void set_query(const float* x) final {
1063
1014
  for (int i = 0; i < d; i++) {
1064
1015
  tmp[i] = int(x[i]);
1065
1016
  }
@@ -1071,44 +1022,41 @@ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1071
1022
  }
1072
1023
 
1073
1024
  /// compute distance of vector i to current query
1074
- float operator () (idx_t i) final {
1075
- return compute_distance (q, codes + i * code_size);
1025
+ float operator()(idx_t i) final {
1026
+ return query_to_code(codes + i * code_size);
1076
1027
  }
1077
1028
 
1078
- float symmetric_dis (idx_t i, idx_t j) override {
1079
- return compute_code_distance (codes + i * code_size,
1080
- codes + j * code_size);
1029
+ float symmetric_dis(idx_t i, idx_t j) override {
1030
+ return compute_code_distance(
1031
+ codes + i * code_size, codes + j * code_size);
1081
1032
  }
1082
1033
 
1083
- float query_to_code (const uint8_t * code) const {
1084
- return compute_code_distance (tmp.data(), code);
1034
+ float query_to_code(const uint8_t* code) const final {
1035
+ return compute_code_distance(tmp.data(), code);
1085
1036
  }
1086
-
1087
1037
  };
1088
1038
 
1089
1039
  #ifdef __AVX2__
1090
1040
 
1091
-
1092
- template<class Similarity>
1041
+ template <class Similarity>
1093
1042
  struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1094
1043
  using Sim = Similarity;
1095
1044
 
1096
1045
  int d;
1097
1046
  std::vector<uint8_t> tmp;
1098
1047
 
1099
- DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1100
- }
1048
+ DistanceComputerByte(int d, const std::vector<float>&) : d(d), tmp(d) {}
1101
1049
 
1102
1050
  int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1103
- const {
1051
+ const {
1104
1052
  // __m256i accu = _mm256_setzero_ps ();
1105
- __m256i accu = _mm256_setzero_si256 ();
1053
+ __m256i accu = _mm256_setzero_si256();
1106
1054
  for (int i = 0; i < d; i += 16) {
1107
1055
  // load 16 bytes, convert to 16 uint16_t
1108
- __m256i c1 = _mm256_cvtepu8_epi16
1109
- (_mm_loadu_si128((__m128i*)(code1 + i)));
1110
- __m256i c2 = _mm256_cvtepu8_epi16
1111
- (_mm_loadu_si128((__m128i*)(code2 + i)));
1056
+ __m256i c1 = _mm256_cvtepu8_epi16(
1057
+ _mm_loadu_si128((__m128i*)(code1 + i)));
1058
+ __m256i c2 = _mm256_cvtepu8_epi16(
1059
+ _mm_loadu_si128((__m128i*)(code2 + i)));
1112
1060
  __m256i prod32;
1113
1061
  if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1114
1062
  prod32 = _mm256_madd_epi16(c1, c2);
@@ -1116,17 +1064,16 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1116
1064
  __m256i diff = _mm256_sub_epi16(c1, c2);
1117
1065
  prod32 = _mm256_madd_epi16(diff, diff);
1118
1066
  }
1119
- accu = _mm256_add_epi32 (accu, prod32);
1120
-
1067
+ accu = _mm256_add_epi32(accu, prod32);
1121
1068
  }
1122
1069
  __m128i sum = _mm256_extractf128_si256(accu, 0);
1123
- sum = _mm_add_epi32 (sum, _mm256_extractf128_si256(accu, 1));
1124
- sum = _mm_hadd_epi32 (sum, sum);
1125
- sum = _mm_hadd_epi32 (sum, sum);
1126
- return _mm_cvtsi128_si32 (sum);
1070
+ sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1));
1071
+ sum = _mm_hadd_epi32(sum, sum);
1072
+ sum = _mm_hadd_epi32(sum, sum);
1073
+ return _mm_cvtsi128_si32(sum);
1127
1074
  }
1128
1075
 
1129
- void set_query (const float *x) final {
1076
+ void set_query(const float* x) final {
1130
1077
  /*
1131
1078
  for (int i = 0; i < d; i += 8) {
1132
1079
  __m256 xi = _mm256_loadu_ps (x + i);
@@ -1143,20 +1090,18 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
  }
 
  /// compute distance of vector i to current query
- float operator () (idx_t i) final {
- return compute_distance (q, codes + i * code_size);
+ float operator()(idx_t i) final {
+ return query_to_code(codes + i * code_size);
  }
 
- float symmetric_dis (idx_t i, idx_t j) override {
- return compute_code_distance (codes + i * code_size,
- codes + j * code_size);
+ float symmetric_dis(idx_t i, idx_t j) override {
+ return compute_code_distance(
+ codes + i * code_size, codes + j * code_size);
  }
 
- float query_to_code (const uint8_t * code) const {
- return compute_code_distance (tmp.data(), code);
+ float query_to_code(const uint8_t* code) const final {
+ return compute_code_distance(tmp.data(), code);
  }
-
-
  };
 
  #endif
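The AVX2 specialization above accumulates 16-bit products into eight 32-bit lanes and then collapses them to a scalar. The same reduction, isolated as a standalone helper (a sketch that only restates the intrinsics sequence from compute_code_distance(); the helper name is made up for illustration):

#include <immintrin.h>

// Reduce the 8 int32 lanes of an AVX2 accumulator to one int, exactly as
// compute_code_distance() does above: add the two 128-bit halves, then two
// horizontal adds leave the total in lane 0.
static inline int horizontal_sum_epi32(__m256i accu) {
    __m128i sum = _mm256_extractf128_si256(accu, 0);
    sum = _mm_add_epi32(sum, _mm256_extractf128_si256(accu, 1));
    sum = _mm_hadd_epi32(sum, sum);
    sum = _mm_hadd_epi32(sum, sum);
    return _mm_cvtsi128_si32(sum);
}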
@@ -1166,215 +1111,218 @@ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
  * specialization
  *******************************************************************/
 
-
- template<class Sim>
- SQDistanceComputer *select_distance_computer (
- QuantizerType qtype,
- size_t d, const std::vector<float> & trained)
- {
+ template <class Sim>
+ SQDistanceComputer* select_distance_computer(
+ QuantizerType qtype,
+ size_t d,
+ const std::vector<float>& trained) {
  constexpr int SIMDWIDTH = Sim::simdwidth;
- switch(qtype) {
- case ScalarQuantizer::QT_8bit_uniform:
- return new DCTemplate<QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
- Sim, SIMDWIDTH>(d, trained);
-
- case ScalarQuantizer::QT_4bit_uniform:
- return new DCTemplate<QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
- Sim, SIMDWIDTH>(d, trained);
-
- case ScalarQuantizer::QT_8bit:
- return new DCTemplate<QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
- Sim, SIMDWIDTH>(d, trained);
-
- case ScalarQuantizer::QT_6bit:
- return new DCTemplate<QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
- Sim, SIMDWIDTH>(d, trained);
-
- case ScalarQuantizer::QT_4bit:
- return new DCTemplate<QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
- Sim, SIMDWIDTH>(d, trained);
-
- case ScalarQuantizer::QT_fp16:
- return new DCTemplate
- <QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
-
- case ScalarQuantizer::QT_8bit_direct:
- if (d % 16 == 0) {
- return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
- } else {
- return new DCTemplate
- <Quantizer8bitDirect<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
- }
+ switch (qtype) {
+ case ScalarQuantizer::QT_8bit_uniform:
+ return new DCTemplate<
+ QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
+
+ case ScalarQuantizer::QT_4bit_uniform:
+ return new DCTemplate<
+ QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
+
+ case ScalarQuantizer::QT_8bit:
+ return new DCTemplate<
+ QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
+
+ case ScalarQuantizer::QT_6bit:
+ return new DCTemplate<
+ QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
+
+ case ScalarQuantizer::QT_4bit:
+ return new DCTemplate<
+ QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
+
+ case ScalarQuantizer::QT_fp16:
+ return new DCTemplate<QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(
+ d, trained);
+
+ case ScalarQuantizer::QT_8bit_direct:
+ if (d % 16 == 0) {
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
+ } else {
+ return new DCTemplate<
+ Quantizer8bitDirect<SIMDWIDTH>,
+ Sim,
+ SIMDWIDTH>(d, trained);
+ }
  }
- FAISS_THROW_MSG ("unknown qtype");
+ FAISS_THROW_MSG("unknown qtype");
  return nullptr;
  }
 
-
-
  } // anonymous namespace
 
-
-
  /*******************************************************************
  * ScalarQuantizer implementation
  ********************************************************************/
 
-
-
- ScalarQuantizer::ScalarQuantizer
- (size_t d, QuantizerType qtype):
- qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d(d)
- {
- set_derived_sizes();
+ ScalarQuantizer::ScalarQuantizer(size_t d, QuantizerType qtype)
+ : qtype(qtype), rangestat(RS_minmax), rangestat_arg(0), d(d) {
+ set_derived_sizes();
  }
 
- ScalarQuantizer::ScalarQuantizer ():
- qtype(QT_8bit),
- rangestat(RS_minmax), rangestat_arg(0), d(0), bits(0), code_size(0)
- {}
+ ScalarQuantizer::ScalarQuantizer()
+ : qtype(QT_8bit),
+ rangestat(RS_minmax),
+ rangestat_arg(0),
+ d(0),
+ bits(0),
+ code_size(0) {}
 
- void ScalarQuantizer::set_derived_sizes ()
- {
+ void ScalarQuantizer::set_derived_sizes() {
  switch (qtype) {
- case QT_8bit:
- case QT_8bit_uniform:
- case QT_8bit_direct:
- code_size = d;
- bits = 8;
- break;
- case QT_4bit:
- case QT_4bit_uniform:
- code_size = (d + 1) / 2;
- bits = 4;
- break;
- case QT_6bit:
- code_size = (d * 6 + 7) / 8;
- bits = 6;
- break;
- case QT_fp16:
- code_size = d * 2;
- bits = 16;
- break;
+ case QT_8bit:
+ case QT_8bit_uniform:
+ case QT_8bit_direct:
+ code_size = d;
+ bits = 8;
+ break;
+ case QT_4bit:
+ case QT_4bit_uniform:
+ code_size = (d + 1) / 2;
+ bits = 4;
+ break;
+ case QT_6bit:
+ code_size = (d * 6 + 7) / 8;
+ bits = 6;
+ break;
+ case QT_fp16:
+ code_size = d * 2;
+ bits = 16;
+ break;
  }
  }
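set_derived_sizes() fixes code_size and bits from the quantizer type and the dimension alone. For instance, with d = 100 the constructor above yields (a sketch; it assumes only the arithmetic in the switch and the usual faiss/impl/ScalarQuantizer.h include):

#include <faiss/impl/ScalarQuantizer.h>

faiss::ScalarQuantizer sq8(100, faiss::ScalarQuantizer::QT_8bit);  // code_size == 100, bits == 8
faiss::ScalarQuantizer sq6(100, faiss::ScalarQuantizer::QT_6bit);  // (100 * 6 + 7) / 8 == 75, bits == 6
faiss::ScalarQuantizer sq4(100, faiss::ScalarQuantizer::QT_4bit);  // (100 + 1) / 2 == 50, bits == 4
faiss::ScalarQuantizer sqf(100, faiss::ScalarQuantizer::QT_fp16);  // 100 * 2 == 200, bits == 16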
 
- void ScalarQuantizer::train (size_t n, const float *x)
- {
- int bit_per_dim =
- qtype == QT_4bit_uniform ? 4 :
- qtype == QT_4bit ? 4 :
- qtype == QT_6bit ? 6 :
- qtype == QT_8bit_uniform ? 8 :
- qtype == QT_8bit ? 8 : -1;
+ void ScalarQuantizer::train(size_t n, const float* x) {
+ int bit_per_dim = qtype == QT_4bit_uniform ? 4
+ : qtype == QT_4bit ? 4
+ : qtype == QT_6bit ? 6
+ : qtype == QT_8bit_uniform ? 8
+ : qtype == QT_8bit ? 8
+ : -1;
 
  switch (qtype) {
- case QT_4bit_uniform: case QT_8bit_uniform:
- train_Uniform (rangestat, rangestat_arg,
- n * d, 1 << bit_per_dim, x, trained);
- break;
- case QT_4bit: case QT_8bit: case QT_6bit:
- train_NonUniform (rangestat, rangestat_arg,
- n, d, 1 << bit_per_dim, x, trained);
- break;
- case QT_fp16:
- case QT_8bit_direct:
- // no training necessary
- break;
+ case QT_4bit_uniform:
+ case QT_8bit_uniform:
+ train_Uniform(
+ rangestat,
+ rangestat_arg,
+ n * d,
+ 1 << bit_per_dim,
+ x,
+ trained);
+ break;
+ case QT_4bit:
+ case QT_8bit:
+ case QT_6bit:
+ train_NonUniform(
+ rangestat,
+ rangestat_arg,
+ n,
+ d,
+ 1 << bit_per_dim,
+ x,
+ trained);
+ break;
+ case QT_fp16:
+ case QT_8bit_direct:
+ // no training necessary
+ break;
  }
  }
 
- void ScalarQuantizer::train_residual(size_t n,
- const float *x,
- Index *quantizer,
- bool by_residual,
- bool verbose)
- {
- const float * x_in = x;
+ void ScalarQuantizer::train_residual(
+ size_t n,
+ const float* x,
+ Index* quantizer,
+ bool by_residual,
+ bool verbose) {
+ const float* x_in = x;
 
  // 100k points more than enough
- x = fvecs_maybe_subsample (
- d, (size_t*)&n, 100000,
- x, verbose, 1234);
+ x = fvecs_maybe_subsample(d, (size_t*)&n, 100000, x, verbose, 1234);
 
- ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
+ ScopeDeleter<float> del_x(x_in == x ? nullptr : x);
 
  if (by_residual) {
  std::vector<Index::idx_t> idx(n);
- quantizer->assign (n, x, idx.data());
+ quantizer->assign(n, x, idx.data());
 
  std::vector<float> residuals(n * d);
- quantizer->compute_residual_n (n, x, residuals.data(), idx.data());
+ quantizer->compute_residual_n(n, x, residuals.data(), idx.data());
 
- train (n, residuals.data());
+ train(n, residuals.data());
  } else {
- train (n, x);
+ train(n, x);
  }
  }
 
-
- ScalarQuantizer::Quantizer *ScalarQuantizer::select_quantizer () const
- {
+ ScalarQuantizer::Quantizer* ScalarQuantizer::select_quantizer() const {
  #ifdef USE_F16C
  if (d % 8 == 0) {
- return select_quantizer_1<8> (qtype, d, trained);
+ return select_quantizer_1<8>(qtype, d, trained);
  } else
  #endif
  {
- return select_quantizer_1<1> (qtype, d, trained);
+ return select_quantizer_1<1>(qtype, d, trained);
  }
  }
 
+ void ScalarQuantizer::compute_codes(const float* x, uint8_t* codes, size_t n)
+ const {
+ std::unique_ptr<Quantizer> squant(select_quantizer());
 
- void ScalarQuantizer::compute_codes (const float * x,
- uint8_t * codes,
- size_t n) const
- {
- std::unique_ptr<Quantizer> squant(select_quantizer ());
-
- memset (codes, 0, code_size * n);
+ memset(codes, 0, code_size * n);
  #pragma omp parallel for
  for (int64_t i = 0; i < n; i++)
- squant->encode_vector (x + i * d, codes + i * code_size);
+ squant->encode_vector(x + i * d, codes + i * code_size);
  }
 
- void ScalarQuantizer::decode (const uint8_t *codes, float *x, size_t n) const
- {
- std::unique_ptr<Quantizer> squant(select_quantizer ());
+ void ScalarQuantizer::decode(const uint8_t* codes, float* x, size_t n) const {
+ std::unique_ptr<Quantizer> squant(select_quantizer());
 
  #pragma omp parallel for
  for (int64_t i = 0; i < n; i++)
- squant->decode_vector (codes + i * code_size, x + i * d);
+ squant->decode_vector(codes + i * code_size, x + i * d);
  }
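Taken together, train(), compute_codes() and decode() form the whole encode/decode cycle of the scalar quantizer. A minimal end-to-end sketch (the header path and link setup are assumptions; everything else uses only the member functions defined above):

#include <faiss/impl/ScalarQuantizer.h>

#include <random>
#include <vector>

int main() {
    size_t d = 64, n = 10000;
    std::vector<float> x(n * d);
    std::mt19937 rng(123);
    std::uniform_real_distribution<float> u(0.0f, 1.0f);
    for (auto& v : x) {
        v = u(rng);
    }

    faiss::ScalarQuantizer sq(d, faiss::ScalarQuantizer::QT_8bit);
    sq.train(n, x.data()); // RS_minmax by default: per-dimension ranges go into sq.trained

    std::vector<uint8_t> codes(n * sq.code_size); // d bytes per vector for QT_8bit
    sq.compute_codes(x.data(), codes.data(), n);

    std::vector<float> x_rec(n * d);
    sq.decode(codes.data(), x_rec.data(), n); // lossy reconstruction of x
    return 0;
}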
 
-
- SQDistanceComputer *
- ScalarQuantizer::get_distance_computer (MetricType metric) const
- {
+ SQDistanceComputer* ScalarQuantizer::get_distance_computer(
+ MetricType metric) const {
  FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
  #ifdef USE_F16C
  if (d % 8 == 0) {
  if (metric == METRIC_L2) {
- return select_distance_computer<SimilarityL2<8> >
- (qtype, d, trained);
+ return select_distance_computer<SimilarityL2<8>>(qtype, d, trained);
  } else {
- return select_distance_computer<SimilarityIP<8> >
- (qtype, d, trained);
+ return select_distance_computer<SimilarityIP<8>>(qtype, d, trained);
  }
  } else
  #endif
  {
  if (metric == METRIC_L2) {
- return select_distance_computer<SimilarityL2<1> >
- (qtype, d, trained);
+ return select_distance_computer<SimilarityL2<1>>(qtype, d, trained);
  } else {
- return select_distance_computer<SimilarityIP<1> >
- (qtype, d, trained);
+ return select_distance_computer<SimilarityIP<1>>(qtype, d, trained);
  }
  }
  }
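get_distance_computer() hands back an SQDistanceComputer whose codes and code_size members (visible in the computations above) point at the encoded database. Continuing the sketch from the previous example; the direct member assignment mirrors how the index classes appear to wire it up, which is an assumption:

#include <memory>

// dc scores the query against codes laid out contiguously, code_size bytes each.
std::unique_ptr<faiss::SQDistanceComputer> dc(
        sq.get_distance_computer(faiss::METRIC_L2));
dc->codes = codes.data();            // encoded database from the sketch above
dc->code_size = sq.code_size;
dc->set_query(x.data());             // vector 0 as the query
float d0 = (*dc)(0);                 // query vs. stored vector 0
float d01 = dc->symmetric_dis(0, 1); // stored vector 0 vs. stored vector 1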
 
-
  /*******************************************************************
  * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
  *
@@ -1384,54 +1332,52 @@ ScalarQuantizer::get_distance_computer (MetricType metric) const
 
  namespace {
 
-
- template<class DCClass>
- struct IVFSQScannerIP: InvertedListScanner {
+ template <class DCClass>
+ struct IVFSQScannerIP : InvertedListScanner {
  DCClass dc;
- bool store_pairs, by_residual;
-
- size_t code_size;
+ bool by_residual;
 
- idx_t list_no; /// current list (set to 0 for Flat index
- float accu0; /// added to all distances
-
- IVFSQScannerIP(int d, const std::vector<float> & trained,
- size_t code_size, bool store_pairs,
- bool by_residual):
- dc(d, trained), store_pairs(store_pairs),
- by_residual(by_residual),
- code_size(code_size), list_no(0), accu0(0)
- {}
+ float accu0; /// added to all distances
 
+ IVFSQScannerIP(
+ int d,
+ const std::vector<float>& trained,
+ size_t code_size,
+ bool store_pairs,
+ bool by_residual)
+ : dc(d, trained), by_residual(by_residual), accu0(0) {
+ this->store_pairs = store_pairs;
+ this->code_size = code_size;
+ }
 
- void set_query (const float *query) override {
- dc.set_query (query);
+ void set_query(const float* query) override {
+ dc.set_query(query);
  }
 
- void set_list (idx_t list_no, float coarse_dis) override {
+ void set_list(idx_t list_no, float coarse_dis) override {
  this->list_no = list_no;
  accu0 = by_residual ? coarse_dis : 0;
  }
 
- float distance_to_code (const uint8_t *code) const final {
- return accu0 + dc.query_to_code (code);
+ float distance_to_code(const uint8_t* code) const final {
+ return accu0 + dc.query_to_code(code);
  }
 
- size_t scan_codes (size_t list_size,
- const uint8_t *codes,
- const idx_t *ids,
- float *simi, idx_t *idxi,
- size_t k) const override
- {
+ size_t scan_codes(
+ size_t list_size,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float* simi,
+ idx_t* idxi,
+ size_t k) const override {
  size_t nup = 0;
 
  for (size_t j = 0; j < list_size; j++) {
+ float accu = accu0 + dc.query_to_code(codes);
 
- float accu = accu0 + dc.query_to_code (codes);
-
- if (accu > simi [0]) {
+ if (accu > simi[0]) {
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
- minheap_replace_top (k, simi, idxi, accu, id);
+ minheap_replace_top(k, simi, idxi, accu, id);
  nup++;
  }
  codes += code_size;
@@ -1439,86 +1385,85 @@ struct IVFSQScannerIP: InvertedListScanner {
  return nup;
  }
 
- void scan_codes_range (size_t list_size,
- const uint8_t *codes,
- const idx_t *ids,
- float radius,
- RangeQueryResult & res) const override
- {
+ void scan_codes_range(
+ size_t list_size,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float radius,
+ RangeQueryResult& res) const override {
  for (size_t j = 0; j < list_size; j++) {
- float accu = accu0 + dc.query_to_code (codes);
+ float accu = accu0 + dc.query_to_code(codes);
  if (accu > radius) {
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
- res.add (accu, id);
+ res.add(accu, id);
  }
  codes += code_size;
  }
  }
-
-
  };
 
-
- template<class DCClass>
- struct IVFSQScannerL2: InvertedListScanner {
-
+ template <class DCClass>
+ struct IVFSQScannerL2 : InvertedListScanner {
  DCClass dc;
 
- bool store_pairs, by_residual;
- size_t code_size;
- const Index *quantizer;
- idx_t list_no; /// current inverted list
- const float *x; /// current query
+ bool by_residual;
+ const Index* quantizer;
+ const float* x; /// current query
 
  std::vector<float> tmp;
 
- IVFSQScannerL2(int d, const std::vector<float> & trained,
- size_t code_size, const Index *quantizer,
- bool store_pairs, bool by_residual):
- dc(d, trained), store_pairs(store_pairs), by_residual(by_residual),
- code_size(code_size), quantizer(quantizer),
- list_no (0), x (nullptr), tmp (d)
- {
- }
-
-
- void set_query (const float *query) override {
+ IVFSQScannerL2(
+ int d,
+ const std::vector<float>& trained,
+ size_t code_size,
+ const Index* quantizer,
+ bool store_pairs,
+ bool by_residual)
+ : dc(d, trained),
+ by_residual(by_residual),
+ quantizer(quantizer),
+ x(nullptr),
+ tmp(d) {
+ this->store_pairs = store_pairs;
+ this->code_size = code_size;
+ }
+
+ void set_query(const float* query) override {
  x = query;
  if (!quantizer) {
- dc.set_query (query);
+ dc.set_query(query);
  }
  }
 
-
- void set_list (idx_t list_no, float /*coarse_dis*/) override {
+ void set_list(idx_t list_no, float /*coarse_dis*/) override {
+ this->list_no = list_no;
  if (by_residual) {
- this->list_no = list_no;
  // shift of x_in wrt centroid
- quantizer->compute_residual (x, tmp.data(), list_no);
- dc.set_query (tmp.data ());
+ quantizer->compute_residual(x, tmp.data(), list_no);
+ dc.set_query(tmp.data());
  } else {
- dc.set_query (x);
+ dc.set_query(x);
  }
  }
 
- float distance_to_code (const uint8_t *code) const final {
- return dc.query_to_code (code);
+ float distance_to_code(const uint8_t* code) const final {
+ return dc.query_to_code(code);
  }
 
- size_t scan_codes (size_t list_size,
- const uint8_t *codes,
- const idx_t *ids,
- float *simi, idx_t *idxi,
- size_t k) const override
- {
+ size_t scan_codes(
+ size_t list_size,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float* simi,
+ idx_t* idxi,
+ size_t k) const override {
  size_t nup = 0;
  for (size_t j = 0; j < list_size; j++) {
+ float dis = dc.query_to_code(codes);
 
- float dis = dc.query_to_code (codes);
-
- if (dis < simi [0]) {
+ if (dis < simi[0]) {
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
- maxheap_replace_top (k, simi, idxi, dis, id);
+ maxheap_replace_top(k, simi, idxi, dis, id);
  nup++;
  }
  codes += code_size;
@@ -1526,137 +1471,132 @@ struct IVFSQScannerL2: InvertedListScanner {
  return nup;
  }
 
- void scan_codes_range (size_t list_size,
- const uint8_t *codes,
- const idx_t *ids,
- float radius,
- RangeQueryResult & res) const override
- {
+ void scan_codes_range(
+ size_t list_size,
+ const uint8_t* codes,
+ const idx_t* ids,
+ float radius,
+ RangeQueryResult& res) const override {
  for (size_t j = 0; j < list_size; j++) {
- float dis = dc.query_to_code (codes);
+ float dis = dc.query_to_code(codes);
  if (dis < radius) {
  int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
- res.add (dis, id);
+ res.add(dis, id);
  }
  codes += code_size;
  }
  }
-
-
  };
 
- template<class DCClass>
- InvertedListScanner* sel2_InvertedListScanner
- (const ScalarQuantizer *sq,
- const Index *quantizer, bool store_pairs, bool r)
- {
+ template <class DCClass>
+ InvertedListScanner* sel2_InvertedListScanner(
+ const ScalarQuantizer* sq,
+ const Index* quantizer,
+ bool store_pairs,
+ bool r) {
  if (DCClass::Sim::metric_type == METRIC_L2) {
- return new IVFSQScannerL2<DCClass>(sq->d, sq->trained, sq->code_size,
- quantizer, store_pairs, r);
+ return new IVFSQScannerL2<DCClass>(
+ sq->d, sq->trained, sq->code_size, quantizer, store_pairs, r);
  } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
- return new IVFSQScannerIP<DCClass>(sq->d, sq->trained, sq->code_size,
- store_pairs, r);
+ return new IVFSQScannerIP<DCClass>(
+ sq->d, sq->trained, sq->code_size, store_pairs, r);
  } else {
  FAISS_THROW_MSG("unsupported metric type");
  }
  }
 
- template<class Similarity, class Codec, bool uniform>
- InvertedListScanner* sel12_InvertedListScanner
- (const ScalarQuantizer *sq,
- const Index *quantizer, bool store_pairs, bool r)
- {
+ template <class Similarity, class Codec, bool uniform>
+ InvertedListScanner* sel12_InvertedListScanner(
+ const ScalarQuantizer* sq,
+ const Index* quantizer,
+ bool store_pairs,
+ bool r) {
  constexpr int SIMDWIDTH = Similarity::simdwidth;
  using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
  using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
- return sel2_InvertedListScanner<DCClass> (sq, quantizer, store_pairs, r);
+ return sel2_InvertedListScanner<DCClass>(sq, quantizer, store_pairs, r);
  }
 
-
-
- template<class Similarity>
- InvertedListScanner* sel1_InvertedListScanner
- (const ScalarQuantizer *sq, const Index *quantizer,
- bool store_pairs, bool r)
- {
+ template <class Similarity>
+ InvertedListScanner* sel1_InvertedListScanner(
+ const ScalarQuantizer* sq,
+ const Index* quantizer,
+ bool store_pairs,
+ bool r) {
  constexpr int SIMDWIDTH = Similarity::simdwidth;
- switch(sq->qtype) {
- case ScalarQuantizer::QT_8bit_uniform:
- return sel12_InvertedListScanner
- <Similarity, Codec8bit, true>(sq, quantizer, store_pairs, r);
- case ScalarQuantizer::QT_4bit_uniform:
- return sel12_InvertedListScanner
- <Similarity, Codec4bit, true>(sq, quantizer, store_pairs, r);
- case ScalarQuantizer::QT_8bit:
- return sel12_InvertedListScanner
- <Similarity, Codec8bit, false>(sq, quantizer, store_pairs, r);
- case ScalarQuantizer::QT_4bit:
- return sel12_InvertedListScanner
- <Similarity, Codec4bit, false>(sq, quantizer, store_pairs, r);
- case ScalarQuantizer::QT_6bit:
- return sel12_InvertedListScanner
- <Similarity, Codec6bit, false>(sq, quantizer, store_pairs, r);
- case ScalarQuantizer::QT_fp16:
- return sel2_InvertedListScanner
- <DCTemplate<QuantizerFP16<SIMDWIDTH>, Similarity, SIMDWIDTH> >
- (sq, quantizer, store_pairs, r);
- case ScalarQuantizer::QT_8bit_direct:
- if (sq->d % 16 == 0) {
- return sel2_InvertedListScanner
- <DistanceComputerByte<Similarity, SIMDWIDTH> >
- (sq, quantizer, store_pairs, r);
- } else {
- return sel2_InvertedListScanner
- <DCTemplate<Quantizer8bitDirect<SIMDWIDTH>,
- Similarity, SIMDWIDTH> >
- (sq, quantizer, store_pairs, r);
- }
-
+ switch (sq->qtype) {
+ case ScalarQuantizer::QT_8bit_uniform:
+ return sel12_InvertedListScanner<Similarity, Codec8bit, true>(
+ sq, quantizer, store_pairs, r);
+ case ScalarQuantizer::QT_4bit_uniform:
+ return sel12_InvertedListScanner<Similarity, Codec4bit, true>(
+ sq, quantizer, store_pairs, r);
+ case ScalarQuantizer::QT_8bit:
+ return sel12_InvertedListScanner<Similarity, Codec8bit, false>(
+ sq, quantizer, store_pairs, r);
+ case ScalarQuantizer::QT_4bit:
+ return sel12_InvertedListScanner<Similarity, Codec4bit, false>(
+ sq, quantizer, store_pairs, r);
+ case ScalarQuantizer::QT_6bit:
+ return sel12_InvertedListScanner<Similarity, Codec6bit, false>(
+ sq, quantizer, store_pairs, r);
+ case ScalarQuantizer::QT_fp16:
+ return sel2_InvertedListScanner<DCTemplate<
+ QuantizerFP16<SIMDWIDTH>,
+ Similarity,
+ SIMDWIDTH>>(sq, quantizer, store_pairs, r);
+ case ScalarQuantizer::QT_8bit_direct:
+ if (sq->d % 16 == 0) {
+ return sel2_InvertedListScanner<
+ DistanceComputerByte<Similarity, SIMDWIDTH>>(
+ sq, quantizer, store_pairs, r);
+ } else {
+ return sel2_InvertedListScanner<DCTemplate<
+ Quantizer8bitDirect<SIMDWIDTH>,
+ Similarity,
+ SIMDWIDTH>>(sq, quantizer, store_pairs, r);
+ }
  }
 
- FAISS_THROW_MSG ("unknown qtype");
+ FAISS_THROW_MSG("unknown qtype");
  return nullptr;
  }
 
- template<int SIMDWIDTH>
- InvertedListScanner* sel0_InvertedListScanner
- (MetricType mt, const ScalarQuantizer *sq,
- const Index *quantizer, bool store_pairs, bool by_residual)
- {
+ template <int SIMDWIDTH>
+ InvertedListScanner* sel0_InvertedListScanner(
+ MetricType mt,
+ const ScalarQuantizer* sq,
+ const Index* quantizer,
+ bool store_pairs,
+ bool by_residual) {
  if (mt == METRIC_L2) {
- return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH> >
- (sq, quantizer, store_pairs, by_residual);
+ return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH>>(
+ sq, quantizer, store_pairs, by_residual);
  } else if (mt == METRIC_INNER_PRODUCT) {
- return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH> >
- (sq, quantizer, store_pairs, by_residual);
+ return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH>>(
+ sq, quantizer, store_pairs, by_residual);
  } else {
  FAISS_THROW_MSG("unsupported metric type");
  }
  }
 
-
-
  } // anonymous namespace
 
-
- InvertedListScanner* ScalarQuantizer::select_InvertedListScanner
- (MetricType mt, const Index *quantizer,
- bool store_pairs, bool by_residual) const
- {
+ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner(
+ MetricType mt,
+ const Index* quantizer,
+ bool store_pairs,
+ bool by_residual) const {
  #ifdef USE_F16C
  if (d % 8 == 0) {
- return sel0_InvertedListScanner<8>
- (mt, this, quantizer, store_pairs, by_residual);
+ return sel0_InvertedListScanner<8>(
+ mt, this, quantizer, store_pairs, by_residual);
  } else
  #endif
  {
- return sel0_InvertedListScanner<1>
- (mt, this, quantizer, store_pairs, by_residual);
+ return sel0_InvertedListScanner<1>(
+ mt, this, quantizer, store_pairs, by_residual);
  }
  }
 
-
-
-
-
  } // namespace faiss
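select_InvertedListScanner() is what the IVF search path calls; the scanner it returns then follows the set_query / set_list / scan_codes protocol implemented by IVFSQScannerIP and IVFSQScannerL2 above. A sketch of scanning a single inverted list (the header paths and the Heap.h helper are assumptions, and list_codes / list_ids / list_size stand in for one list's contents):

#include <faiss/IndexIVF.h>             // InvertedListScanner (assumed path)
#include <faiss/impl/ScalarQuantizer.h> // assumed path
#include <faiss/utils/Heap.h>           // maxheap_heapify (assumed path)

#include <memory>

// Fill distances/labels with the k best (smallest) L2 matches from one list.
void scan_one_list(
        const faiss::ScalarQuantizer& sq,
        const faiss::Index* coarse_quantizer,
        const float* query,
        faiss::Index::idx_t list_no,
        float coarse_dis,
        size_t list_size,
        const uint8_t* list_codes,
        const faiss::Index::idx_t* list_ids,
        size_t k,
        float* distances,
        faiss::Index::idx_t* labels) {
    std::unique_ptr<faiss::InvertedListScanner> scanner(
            sq.select_InvertedListScanner(
                    faiss::METRIC_L2,
                    coarse_quantizer,
                    /*store_pairs=*/false,
                    /*by_residual=*/true));

    faiss::maxheap_heapify(k, distances, labels); // L2 keeps a max-heap of the current top-k
    scanner->set_query(query);              // raw query; the residual shift happens in set_list
    scanner->set_list(list_no, coarse_dis);
    scanner->scan_codes(list_size, list_codes, list_ids, distances, labels, k);
}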