faiss 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (226) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,242 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_PRODUCT_QUANTIZER_H
11
+ #define FAISS_PRODUCT_QUANTIZER_H
12
+
13
+ #include <stdint.h>
14
+
15
+ #include <vector>
16
+
17
+ #include <faiss/Clustering.h>
18
+ #include <faiss/utils/Heap.h>
19
+
20
+ namespace faiss {
21
+
22
+ /** Product Quantizer. Implemented only for METRIC_L2 */
23
+ struct ProductQuantizer {
24
+
25
+ using idx_t = Index::idx_t;
26
+
27
+ size_t d; ///< size of the input vectors
28
+ size_t M; ///< number of subquantizers
29
+ size_t nbits; ///< number of bits per quantization index
30
+
31
+ // values derived from the above
32
+ size_t dsub; ///< dimensionality of each subvector
33
+ size_t code_size; ///< bytes per indexed vector
34
+ size_t ksub; ///< number of centroids for each subquantizer
35
+ bool verbose; ///< verbose during training?
36
+
37
+ /// initialization
38
+ enum train_type_t {
39
+ Train_default,
40
+ Train_hot_start, ///< the centroids are already initialized
41
+ Train_shared, ///< share dictionary accross PQ segments
42
+ Train_hypercube, ///< intialize centroids with nbits-D hypercube
43
+ Train_hypercube_pca, ///< intialize centroids with nbits-D hypercube
44
+ };
45
+ train_type_t train_type;
46
+
47
+ ClusteringParameters cp; ///< parameters used during clustering
48
+
49
+ /// if non-NULL, use this index for assignment (should be of size
50
+ /// d / M)
51
+ Index *assign_index;
52
+
53
+ /// Centroid table, size M * ksub * dsub
54
+ std::vector<float> centroids;
55
+
56
+ /// return the centroids associated with subvector m
57
+ float * get_centroids (size_t m, size_t i) {
58
+ return &centroids [(m * ksub + i) * dsub];
59
+ }
60
+ const float * get_centroids (size_t m, size_t i) const {
61
+ return &centroids [(m * ksub + i) * dsub];
62
+ }
63
+
64
+ // Train the product quantizer on a set of points. A clustering
65
+ // can be set on input to define non-default clustering parameters
66
+ void train (int n, const float *x);
67
+
68
+ ProductQuantizer(size_t d, /* dimensionality of the input vectors */
69
+ size_t M, /* number of subquantizers */
70
+ size_t nbits); /* number of bit per subvector index */
71
+
72
+ ProductQuantizer ();
73
+
74
+ /// compute derived values when d, M and nbits have been set
75
+ void set_derived_values ();
76
+
77
+ /// Define the centroids for subquantizer m
78
+ void set_params (const float * centroids, int m);
79
+
80
+ /// Quantize one vector with the product quantizer
81
+ void compute_code (const float * x, uint8_t * code) const ;
82
+
83
+ /// same as compute_code for several vectors
84
+ void compute_codes (const float * x,
85
+ uint8_t * codes,
86
+ size_t n) const ;
87
+
88
+ /// speed up code assignment using assign_index
89
+ /// (non-const because the index is changed)
90
+ void compute_codes_with_assign_index (
91
+ const float * x,
92
+ uint8_t * codes,
93
+ size_t n);
94
+
95
+ /// decode a vector from a given code (or n vectors if third argument)
96
+ void decode (const uint8_t *code, float *x) const;
97
+ void decode (const uint8_t *code, float *x, size_t n) const;
98
+
99
+ /// If we happen to have the distance tables precomputed, this is
100
+ /// more efficient to compute the codes.
101
+ void compute_code_from_distance_table (const float *tab,
102
+ uint8_t *code) const;
103
+
104
+
105
+ /** Compute distance table for one vector.
106
+ *
107
+ * The distance table for x = [x_0 x_1 .. x_(M-1)] is a M * ksub
108
+ * matrix that contains
109
+ *
110
+ * dis_table (m, j) = || x_m - c_(m, j)||^2
111
+ * for m = 0..M-1 and j = 0 .. ksub - 1
112
+ *
113
+ * where c_(m, j) is the centroid no j of sub-quantizer m.
114
+ *
115
+ * @param x input vector size d
116
+ * @param dis_table output table, size M * ksub
117
+ */
118
+ void compute_distance_table (const float * x,
119
+ float * dis_table) const;
120
+
121
+ void compute_inner_prod_table (const float * x,
122
+ float * dis_table) const;
123
+
124
+
125
+ /** compute distance table for several vectors
126
+ * @param nx nb of input vectors
127
+ * @param x input vector size nx * d
128
+ * @param dis_table output table, size nx * M * ksub
129
+ */
130
+ void compute_distance_tables (size_t nx,
131
+ const float * x,
132
+ float * dis_tables) const;
133
+
134
+ void compute_inner_prod_tables (size_t nx,
135
+ const float * x,
136
+ float * dis_tables) const;
137
+
138
+
139
+ /** perform a search (L2 distance)
140
+ * @param x query vectors, size nx * d
141
+ * @param nx nb of queries
142
+ * @param codes database codes, size ncodes * code_size
143
+ * @param ncodes nb of nb vectors
144
+ * @param res heap array to store results (nh == nx)
145
+ * @param init_finalize_heap initialize heap (input) and sort (output)?
146
+ */
147
+ void search (const float * x,
148
+ size_t nx,
149
+ const uint8_t * codes,
150
+ const size_t ncodes,
151
+ float_maxheap_array_t *res,
152
+ bool init_finalize_heap = true) const;
153
+
154
+ /** same as search, but with inner product similarity */
155
+ void search_ip (const float * x,
156
+ size_t nx,
157
+ const uint8_t * codes,
158
+ const size_t ncodes,
159
+ float_minheap_array_t *res,
160
+ bool init_finalize_heap = true) const;
161
+
162
+
163
+ /// Symmetric Distance Table
164
+ std::vector<float> sdc_table;
165
+
166
+ // intitialize the SDC table from the centroids
167
+ void compute_sdc_table ();
168
+
169
+ void search_sdc (const uint8_t * qcodes,
170
+ size_t nq,
171
+ const uint8_t * bcodes,
172
+ const size_t ncodes,
173
+ float_maxheap_array_t * res,
174
+ bool init_finalize_heap = true) const;
175
+
176
+ struct PQEncoderGeneric {
177
+ uint8_t *code; ///< code for this vector
178
+ uint8_t offset;
179
+ const int nbits; ///< number of bits per subquantizer index
180
+
181
+ uint8_t reg;
182
+
183
+ PQEncoderGeneric(uint8_t *code, int nbits, uint8_t offset = 0);
184
+
185
+ void encode(uint64_t x);
186
+
187
+ ~PQEncoderGeneric();
188
+ };
189
+
190
+
191
+ struct PQEncoder8 {
192
+ uint8_t *code;
193
+
194
+ PQEncoder8(uint8_t *code, int nbits);
195
+
196
+ void encode(uint64_t x);
197
+ };
198
+
199
+ struct PQEncoder16 {
200
+ uint16_t *code;
201
+
202
+ PQEncoder16(uint8_t *code, int nbits);
203
+
204
+ void encode(uint64_t x);
205
+ };
206
+
207
+
208
+ struct PQDecoderGeneric {
209
+ const uint8_t *code;
210
+ uint8_t offset;
211
+ const int nbits;
212
+ const uint64_t mask;
213
+ uint8_t reg;
214
+
215
+ PQDecoderGeneric(const uint8_t *code, int nbits);
216
+
217
+ uint64_t decode();
218
+ };
219
+
220
+ struct PQDecoder8 {
221
+ const uint8_t *code;
222
+
223
+ PQDecoder8(const uint8_t *code, int nbits);
224
+
225
+ uint64_t decode();
226
+ };
227
+
228
+ struct PQDecoder16 {
229
+ const uint16_t *code;
230
+
231
+ PQDecoder16(const uint8_t *code, int nbits);
232
+
233
+ uint64_t decode();
234
+ };
235
+
236
+ };
237
+
238
+
239
+ } // namespace faiss
240
+
241
+
242
+ #endif
@@ -0,0 +1,1628 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #include <faiss/impl/ScalarQuantizer.h>
11
+
12
+ #include <cstdio>
13
+ #include <algorithm>
14
+
15
+ #include <omp.h>
16
+
17
+ #ifdef __SSE__
18
+ #include <immintrin.h>
19
+ #endif
20
+
21
+ #include <faiss/utils/utils.h>
22
+ #include <faiss/impl/FaissAssert.h>
23
+ #include <faiss/impl/AuxIndexStructures.h>
24
+
25
+ namespace faiss {
26
+
27
+ /*******************************************************************
28
+ * ScalarQuantizer implementation
29
+ *
30
+ * The main source of complexity is to support combinations of 4
31
+ * variants without incurring runtime tests or virtual function calls:
32
+ *
33
+ * - 4 / 8 bits per code component
34
+ * - uniform / non-uniform
35
+ * - IP / L2 distance search
36
+ * - scalar / AVX distance computation
37
+ *
38
+ * The appropriate Quantizer object is returned via select_quantizer
39
+ * that hides the template mess.
40
+ ********************************************************************/
41
+
42
+ #ifdef __AVX__
43
+ #define USE_AVX
44
+ #endif
45
+
46
+ #ifdef __F16C__
47
+ #define USE_F16C
48
+ #endif
49
+
50
+
51
+ namespace {
52
+
53
+ typedef Index::idx_t idx_t;
54
+ typedef ScalarQuantizer::QuantizerType QuantizerType;
55
+ typedef ScalarQuantizer::RangeStat RangeStat;
56
+ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
57
+
58
+
59
+ /*******************************************************************
60
+ * Codec: converts between values in [0, 1] and an index in a code
61
+ * array. The "i" parameter is the vector component index (not byte
62
+ * index).
63
+ */
64
+
65
+ struct Codec8bit {
66
+
67
+ static void encode_component (float x, uint8_t *code, int i) {
68
+ code[i] = (int)(255 * x);
69
+ }
70
+
71
+ static float decode_component (const uint8_t *code, int i) {
72
+ return (code[i] + 0.5f) / 255.0f;
73
+ }
74
+
75
+ #ifdef USE_AVX
76
+ static __m256 decode_8_components (const uint8_t *code, int i) {
77
+ uint64_t c8 = *(uint64_t*)(code + i);
78
+ __m128i c4lo = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8));
79
+ __m128i c4hi = _mm_cvtepu8_epi32 (_mm_set1_epi32(c8 >> 32));
80
+ // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
81
+ __m256i i8 = _mm256_castsi128_si256 (c4lo);
82
+ i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
83
+ __m256 f8 = _mm256_cvtepi32_ps (i8);
84
+ __m256 half = _mm256_set1_ps (0.5f);
85
+ f8 += half;
86
+ __m256 one_255 = _mm256_set1_ps (1.f / 255.f);
87
+ return f8 * one_255;
88
+ }
89
+ #endif
90
+ };
91
+
92
+
93
+ struct Codec4bit {
94
+
95
+ static void encode_component (float x, uint8_t *code, int i) {
96
+ code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
97
+ }
98
+
99
+ static float decode_component (const uint8_t *code, int i) {
100
+ return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
101
+ }
102
+
103
+
104
+ #ifdef USE_AVX
105
+ static __m256 decode_8_components (const uint8_t *code, int i) {
106
+ uint32_t c4 = *(uint32_t*)(code + (i >> 1));
107
+ uint32_t mask = 0x0f0f0f0f;
108
+ uint32_t c4ev = c4 & mask;
109
+ uint32_t c4od = (c4 >> 4) & mask;
110
+
111
+ // the 8 lower bytes of c8 contain the values
112
+ __m128i c8 = _mm_unpacklo_epi8 (_mm_set1_epi32(c4ev),
113
+ _mm_set1_epi32(c4od));
114
+ __m128i c4lo = _mm_cvtepu8_epi32 (c8);
115
+ __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128(c8, 4));
116
+ __m256i i8 = _mm256_castsi128_si256 (c4lo);
117
+ i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
118
+ __m256 f8 = _mm256_cvtepi32_ps (i8);
119
+ __m256 half = _mm256_set1_ps (0.5f);
120
+ f8 += half;
121
+ __m256 one_255 = _mm256_set1_ps (1.f / 15.f);
122
+ return f8 * one_255;
123
+ }
124
+ #endif
125
+ };
126
+
127
+ struct Codec6bit {
128
+
129
+ static void encode_component (float x, uint8_t *code, int i) {
130
+ int bits = (int)(x * 63.0);
131
+ code += (i >> 2) * 3;
132
+ switch(i & 3) {
133
+ case 0:
134
+ code[0] |= bits;
135
+ break;
136
+ case 1:
137
+ code[0] |= bits << 6;
138
+ code[1] |= bits >> 2;
139
+ break;
140
+ case 2:
141
+ code[1] |= bits << 4;
142
+ code[2] |= bits >> 4;
143
+ break;
144
+ case 3:
145
+ code[2] |= bits << 2;
146
+ break;
147
+ }
148
+ }
149
+
150
+ static float decode_component (const uint8_t *code, int i) {
151
+ uint8_t bits;
152
+ code += (i >> 2) * 3;
153
+ switch(i & 3) {
154
+ case 0:
155
+ bits = code[0] & 0x3f;
156
+ break;
157
+ case 1:
158
+ bits = code[0] >> 6;
159
+ bits |= (code[1] & 0xf) << 2;
160
+ break;
161
+ case 2:
162
+ bits = code[1] >> 4;
163
+ bits |= (code[2] & 3) << 4;
164
+ break;
165
+ case 3:
166
+ bits = code[2] >> 2;
167
+ break;
168
+ }
169
+ return (bits + 0.5f) / 63.0f;
170
+ }
171
+
172
+ #ifdef USE_AVX
173
+ static __m256 decode_8_components (const uint8_t *code, int i) {
174
+ return _mm256_set_ps
175
+ (decode_component(code, i + 7),
176
+ decode_component(code, i + 6),
177
+ decode_component(code, i + 5),
178
+ decode_component(code, i + 4),
179
+ decode_component(code, i + 3),
180
+ decode_component(code, i + 2),
181
+ decode_component(code, i + 1),
182
+ decode_component(code, i + 0));
183
+ }
184
+ #endif
185
+ };
186
+
187
+
188
+
189
+ #ifdef USE_F16C
190
+
191
+
192
+ uint16_t encode_fp16 (float x) {
193
+ __m128 xf = _mm_set1_ps (x);
194
+ __m128i xi = _mm_cvtps_ph (
195
+ xf, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
196
+ return _mm_cvtsi128_si32 (xi) & 0xffff;
197
+ }
198
+
199
+
200
+ float decode_fp16 (uint16_t x) {
201
+ __m128i xi = _mm_set1_epi16 (x);
202
+ __m128 xf = _mm_cvtph_ps (xi);
203
+ return _mm_cvtss_f32 (xf);
204
+ }
205
+
206
+ #else
207
+
208
+ // non-intrinsic FP16 <-> FP32 code adapted from
209
+ // https://github.com/ispc/ispc/blob/master/stdlib.ispc
210
+
211
+ float floatbits (uint32_t x) {
212
+ void *xptr = &x;
213
+ return *(float*)xptr;
214
+ }
215
+
216
+ uint32_t intbits (float f) {
217
+ void *fptr = &f;
218
+ return *(uint32_t*)fptr;
219
+ }
220
+
221
+
222
+ uint16_t encode_fp16 (float f) {
223
+
224
+ // via Fabian "ryg" Giesen.
225
+ // https://gist.github.com/2156668
226
+ uint32_t sign_mask = 0x80000000u;
227
+ int32_t o;
228
+
229
+ uint32_t fint = intbits(f);
230
+ uint32_t sign = fint & sign_mask;
231
+ fint ^= sign;
232
+
233
+ // NOTE all the integer compares in this function can be safely
234
+ // compiled into signed compares since all operands are below
235
+ // 0x80000000. Important if you want fast straight SSE2 code (since
236
+ // there's no unsigned PCMPGTD).
237
+
238
+ // Inf or NaN (all exponent bits set)
239
+ // NaN->qNaN and Inf->Inf
240
+ // unconditional assignment here, will override with right value for
241
+ // the regular case below.
242
+ uint32_t f32infty = 255u << 23;
243
+ o = (fint > f32infty) ? 0x7e00u : 0x7c00u;
244
+
245
+ // (De)normalized number or zero
246
+ // update fint unconditionally to save the blending; we don't need it
247
+ // anymore for the Inf/NaN case anyway.
248
+
249
+ const uint32_t round_mask = ~0xfffu;
250
+ const uint32_t magic = 15u << 23;
251
+
252
+ // Shift exponent down, denormalize if necessary.
253
+ // NOTE This represents half-float denormals using single
254
+ // precision denormals. The main reason to do this is that
255
+ // there's no shift with per-lane variable shifts in SSE*, which
256
+ // we'd otherwise need. It has some funky side effects though:
257
+ // - This conversion will actually respect the FTZ (Flush To Zero)
258
+ // flag in MXCSR - if it's set, no half-float denormals will be
259
+ // generated. I'm honestly not sure whether this is good or
260
+ // bad. It's definitely interesting.
261
+ // - If the underlying HW doesn't support denormals (not an issue
262
+ // with Intel CPUs, but might be a problem on GPUs or PS3 SPUs),
263
+ // you will always get flush-to-zero behavior. This is bad,
264
+ // unless you're on a CPU where you don't care.
265
+ // - Denormals tend to be slow. FP32 denormals are rare in
266
+ // practice outside of things like recursive filters in DSP -
267
+ // not a typical half-float application. Whether FP16 denormals
268
+ // are rare in practice, I don't know. Whatever slow path your
269
+ // HW may or may not have for denormals, this may well hit it.
270
+ float fscale = floatbits(fint & round_mask) * floatbits(magic);
271
+ fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
272
+ int32_t fint2 = intbits(fscale) - round_mask;
273
+
274
+ if (fint < f32infty)
275
+ o = fint2 >> 13; // Take the bits!
276
+
277
+ return (o | (sign >> 16));
278
+ }
279
+
280
+ float decode_fp16 (uint16_t h) {
281
+
282
+ // https://gist.github.com/2144712
283
+ // Fabian "ryg" Giesen.
284
+
285
+ const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift
286
+
287
+ int32_t o = ((int32_t)(h & 0x7fffu)) << 13; // exponent/mantissa bits
288
+ int32_t exp = shifted_exp & o; // just the exponent
289
+ o += (int32_t)(127 - 15) << 23; // exponent adjust
290
+
291
+ int32_t infnan_val = o + ((int32_t)(128 - 16) << 23);
292
+ int32_t zerodenorm_val = intbits(
293
+ floatbits(o + (1u<<23)) - floatbits(113u << 23));
294
+ int32_t reg_val = (exp == 0) ? zerodenorm_val : o;
295
+
296
+ int32_t sign_bit = ((int32_t)(h & 0x8000u)) << 16;
297
+ return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
298
+ }
299
+
300
+ #endif
301
+
302
+
303
+
304
+ /*******************************************************************
305
+ * Quantizer: normalizes scalar vector components, then passes them
306
+ * through a codec
307
+ *******************************************************************/
308
+
309
+
310
+
311
+
312
+
313
+ template<class Codec, bool uniform, int SIMD>
314
+ struct QuantizerTemplate {};
315
+
316
+
317
+ template<class Codec>
318
+ struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
319
+ const size_t d;
320
+ const float vmin, vdiff;
321
+
322
+ QuantizerTemplate(size_t d, const std::vector<float> &trained):
323
+ d(d), vmin(trained[0]), vdiff(trained[1])
324
+ {
325
+ }
326
+
327
+ void encode_vector(const float* x, uint8_t* code) const final {
328
+ for (size_t i = 0; i < d; i++) {
329
+ float xi = (x[i] - vmin) / vdiff;
330
+ if (xi < 0) {
331
+ xi = 0;
332
+ }
333
+ if (xi > 1.0) {
334
+ xi = 1.0;
335
+ }
336
+ Codec::encode_component(xi, code, i);
337
+ }
338
+ }
339
+
340
+ void decode_vector(const uint8_t* code, float* x) const final {
341
+ for (size_t i = 0; i < d; i++) {
342
+ float xi = Codec::decode_component(code, i);
343
+ x[i] = vmin + xi * vdiff;
344
+ }
345
+ }
346
+
347
+ float reconstruct_component (const uint8_t * code, int i) const
348
+ {
349
+ float xi = Codec::decode_component (code, i);
350
+ return vmin + xi * vdiff;
351
+ }
352
+
353
+ };
354
+
355
+
356
+
357
+ #ifdef USE_AVX
358
+
359
+ template<class Codec>
360
+ struct QuantizerTemplate<Codec, true, 8>: QuantizerTemplate<Codec, true, 1> {
361
+
362
+ QuantizerTemplate (size_t d, const std::vector<float> &trained):
363
+ QuantizerTemplate<Codec, true, 1> (d, trained) {}
364
+
365
+ __m256 reconstruct_8_components (const uint8_t * code, int i) const
366
+ {
367
+ __m256 xi = Codec::decode_8_components (code, i);
368
+ return _mm256_set1_ps(this->vmin) + xi * _mm256_set1_ps (this->vdiff);
369
+ }
370
+
371
+ };
372
+
373
+ #endif
374
+
375
+
376
+
377
+ template<class Codec>
378
+ struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
379
+ const size_t d;
380
+ const float *vmin, *vdiff;
381
+
382
+ QuantizerTemplate (size_t d, const std::vector<float> &trained):
383
+ d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
384
+
385
+ void encode_vector(const float* x, uint8_t* code) const final {
386
+ for (size_t i = 0; i < d; i++) {
387
+ float xi = (x[i] - vmin[i]) / vdiff[i];
388
+ if (xi < 0)
389
+ xi = 0;
390
+ if (xi > 1.0)
391
+ xi = 1.0;
392
+ Codec::encode_component(xi, code, i);
393
+ }
394
+ }
395
+
396
+ void decode_vector(const uint8_t* code, float* x) const final {
397
+ for (size_t i = 0; i < d; i++) {
398
+ float xi = Codec::decode_component(code, i);
399
+ x[i] = vmin[i] + xi * vdiff[i];
400
+ }
401
+ }
402
+
403
+ float reconstruct_component (const uint8_t * code, int i) const
404
+ {
405
+ float xi = Codec::decode_component (code, i);
406
+ return vmin[i] + xi * vdiff[i];
407
+ }
408
+
409
+ };
410
+
411
+
412
+ #ifdef USE_AVX
413
+
414
+ template<class Codec>
415
+ struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {
416
+
417
+ QuantizerTemplate (size_t d, const std::vector<float> &trained):
418
+ QuantizerTemplate<Codec, false, 1> (d, trained) {}
419
+
420
+ __m256 reconstruct_8_components (const uint8_t * code, int i) const
421
+ {
422
+ __m256 xi = Codec::decode_8_components (code, i);
423
+ return _mm256_loadu_ps (this->vmin + i) + xi * _mm256_loadu_ps (this->vdiff + i);
424
+ }
425
+
426
+
427
+ };
428
+
429
+ #endif
430
+
431
+ /*******************************************************************
432
+ * FP16 quantizer
433
+ *******************************************************************/
434
+
435
+ template<int SIMDWIDTH>
436
+ struct QuantizerFP16 {};
437
+
438
+ template<>
439
+ struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
440
+ const size_t d;
441
+
442
+ QuantizerFP16(size_t d, const std::vector<float> & /* unused */):
443
+ d(d) {}
444
+
445
+ void encode_vector(const float* x, uint8_t* code) const final {
446
+ for (size_t i = 0; i < d; i++) {
447
+ ((uint16_t*)code)[i] = encode_fp16(x[i]);
448
+ }
449
+ }
450
+
451
+ void decode_vector(const uint8_t* code, float* x) const final {
452
+ for (size_t i = 0; i < d; i++) {
453
+ x[i] = decode_fp16(((uint16_t*)code)[i]);
454
+ }
455
+ }
456
+
457
+ float reconstruct_component (const uint8_t * code, int i) const
458
+ {
459
+ return decode_fp16(((uint16_t*)code)[i]);
460
+ }
461
+
462
+ };
463
+
464
+ #ifdef USE_F16C
465
+
466
+ template<>
467
+ struct QuantizerFP16<8>: QuantizerFP16<1> {
468
+
469
+ QuantizerFP16 (size_t d, const std::vector<float> &trained):
470
+ QuantizerFP16<1> (d, trained) {}
471
+
472
+ __m256 reconstruct_8_components (const uint8_t * code, int i) const
473
+ {
474
+ __m128i codei = _mm_loadu_si128 ((const __m128i*)(code + 2 * i));
475
+ return _mm256_cvtph_ps (codei);
476
+ }
477
+
478
+ };
479
+
480
+ #endif
481
+
482
+ /*******************************************************************
483
+ * 8bit_direct quantizer
484
+ *******************************************************************/
485
+
486
+ template<int SIMDWIDTH>
487
+ struct Quantizer8bitDirect {};
488
+
489
+ template<>
490
+ struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
491
+ const size_t d;
492
+
493
+ Quantizer8bitDirect(size_t d, const std::vector<float> & /* unused */):
494
+ d(d) {}
495
+
496
+
497
+ void encode_vector(const float* x, uint8_t* code) const final {
498
+ for (size_t i = 0; i < d; i++) {
499
+ code[i] = (uint8_t)x[i];
500
+ }
501
+ }
502
+
503
+ void decode_vector(const uint8_t* code, float* x) const final {
504
+ for (size_t i = 0; i < d; i++) {
505
+ x[i] = code[i];
506
+ }
507
+ }
508
+
509
+ float reconstruct_component (const uint8_t * code, int i) const
510
+ {
511
+ return code[i];
512
+ }
513
+
514
+ };
515
+
516
+ #ifdef USE_AVX
517
+
518
+ template<>
519
+ struct Quantizer8bitDirect<8>: Quantizer8bitDirect<1> {
520
+
521
+ Quantizer8bitDirect (size_t d, const std::vector<float> &trained):
522
+ Quantizer8bitDirect<1> (d, trained) {}
523
+
524
+ __m256 reconstruct_8_components (const uint8_t * code, int i) const
525
+ {
526
+ __m128i x8 = _mm_loadl_epi64((__m128i*)(code + i)); // 8 * int8
527
+ __m256i y8 = _mm256_cvtepu8_epi32 (x8); // 8 * int32
528
+ return _mm256_cvtepi32_ps (y8); // 8 * float32
529
+ }
530
+
531
+ };
532
+
533
+ #endif
534
+
535
+
536
+ template<int SIMDWIDTH>
537
+ ScalarQuantizer::Quantizer *select_quantizer_1 (
538
+ QuantizerType qtype,
539
+ size_t d, const std::vector<float> & trained)
540
+ {
541
+ switch(qtype) {
542
+ case ScalarQuantizer::QT_8bit:
543
+ return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(d, trained);
544
+ case ScalarQuantizer::QT_6bit:
545
+ return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(d, trained);
546
+ case ScalarQuantizer::QT_4bit:
547
+ return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(d, trained);
548
+ case ScalarQuantizer::QT_8bit_uniform:
549
+ return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(d, trained);
550
+ case ScalarQuantizer::QT_4bit_uniform:
551
+ return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(d, trained);
552
+ case ScalarQuantizer::QT_fp16:
553
+ return new QuantizerFP16<SIMDWIDTH> (d, trained);
554
+ case ScalarQuantizer::QT_8bit_direct:
555
+ return new Quantizer8bitDirect<SIMDWIDTH> (d, trained);
556
+ }
557
+ FAISS_THROW_MSG ("unknown qtype");
558
+ }
559
+
560
+
561
+
562
+
563
+ /*******************************************************************
564
+ * Quantizer range training
565
+ */
566
+
567
+ static float sqr (float x) {
568
+ return x * x;
569
+ }
570
+
571
+
572
+ void train_Uniform(RangeStat rs, float rs_arg,
573
+ idx_t n, int k, const float *x,
574
+ std::vector<float> & trained)
575
+ {
576
+ trained.resize (2);
577
+ float & vmin = trained[0];
578
+ float & vmax = trained[1];
579
+
580
+ if (rs == ScalarQuantizer::RS_minmax) {
581
+ vmin = HUGE_VAL; vmax = -HUGE_VAL;
582
+ for (size_t i = 0; i < n; i++) {
583
+ if (x[i] < vmin) vmin = x[i];
584
+ if (x[i] > vmax) vmax = x[i];
585
+ }
586
+ float vexp = (vmax - vmin) * rs_arg;
587
+ vmin -= vexp;
588
+ vmax += vexp;
589
+ } else if (rs == ScalarQuantizer::RS_meanstd) {
590
+ double sum = 0, sum2 = 0;
591
+ for (size_t i = 0; i < n; i++) {
592
+ sum += x[i];
593
+ sum2 += x[i] * x[i];
594
+ }
595
+ float mean = sum / n;
596
+ float var = sum2 / n - mean * mean;
597
+ float std = var <= 0 ? 1.0 : sqrt(var);
598
+
599
+ vmin = mean - std * rs_arg ;
600
+ vmax = mean + std * rs_arg ;
601
+ } else if (rs == ScalarQuantizer::RS_quantiles) {
602
+ std::vector<float> x_copy(n);
603
+ memcpy(x_copy.data(), x, n * sizeof(*x));
604
+ // TODO just do a qucikselect
605
+ std::sort(x_copy.begin(), x_copy.end());
606
+ int o = int(rs_arg * n);
607
+ if (o < 0) o = 0;
608
+ if (o > n - o) o = n / 2;
609
+ vmin = x_copy[o];
610
+ vmax = x_copy[n - 1 - o];
611
+
612
+ } else if (rs == ScalarQuantizer::RS_optim) {
613
+ float a, b;
614
+ float sx = 0;
615
+ {
616
+ vmin = HUGE_VAL, vmax = -HUGE_VAL;
617
+ for (size_t i = 0; i < n; i++) {
618
+ if (x[i] < vmin) vmin = x[i];
619
+ if (x[i] > vmax) vmax = x[i];
620
+ sx += x[i];
621
+ }
622
+ b = vmin;
623
+ a = (vmax - vmin) / (k - 1);
624
+ }
625
+ int verbose = false;
626
+ int niter = 2000;
627
+ float last_err = -1;
628
+ int iter_last_err = 0;
629
+ for (int it = 0; it < niter; it++) {
630
+ float sn = 0, sn2 = 0, sxn = 0, err1 = 0;
631
+
632
+ for (idx_t i = 0; i < n; i++) {
633
+ float xi = x[i];
634
+ float ni = floor ((xi - b) / a + 0.5);
635
+ if (ni < 0) ni = 0;
636
+ if (ni >= k) ni = k - 1;
637
+ err1 += sqr (xi - (ni * a + b));
638
+ sn += ni;
639
+ sn2 += ni * ni;
640
+ sxn += ni * xi;
641
+ }
642
+
643
+ if (err1 == last_err) {
644
+ iter_last_err ++;
645
+ if (iter_last_err == 16) break;
646
+ } else {
647
+ last_err = err1;
648
+ iter_last_err = 0;
649
+ }
650
+
651
+ float det = sqr (sn) - sn2 * n;
652
+
653
+ b = (sn * sxn - sn2 * sx) / det;
654
+ a = (sn * sx - n * sxn) / det;
655
+ if (verbose) {
656
+ printf ("it %d, err1=%g \r", it, err1);
657
+ fflush(stdout);
658
+ }
659
+ }
660
+ if (verbose) printf("\n");
661
+
662
+ vmin = b;
663
+ vmax = b + a * (k - 1);
664
+
665
+ } else {
666
+ FAISS_THROW_MSG ("Invalid qtype");
667
+ }
668
+ vmax -= vmin;
669
+ }
670
+
671
+ void train_NonUniform(RangeStat rs, float rs_arg,
672
+ idx_t n, int d, int k, const float *x,
673
+ std::vector<float> & trained)
674
+ {
675
+
676
+ trained.resize (2 * d);
677
+ float * vmin = trained.data();
678
+ float * vmax = trained.data() + d;
679
+ if (rs == ScalarQuantizer::RS_minmax) {
680
+ memcpy (vmin, x, sizeof(*x) * d);
681
+ memcpy (vmax, x, sizeof(*x) * d);
682
+ for (size_t i = 1; i < n; i++) {
683
+ const float *xi = x + i * d;
684
+ for (size_t j = 0; j < d; j++) {
685
+ if (xi[j] < vmin[j]) vmin[j] = xi[j];
686
+ if (xi[j] > vmax[j]) vmax[j] = xi[j];
687
+ }
688
+ }
689
+ float *vdiff = vmax;
690
+ for (size_t j = 0; j < d; j++) {
691
+ float vexp = (vmax[j] - vmin[j]) * rs_arg;
692
+ vmin[j] -= vexp;
693
+ vmax[j] += vexp;
694
+ vdiff [j] = vmax[j] - vmin[j];
695
+ }
696
+ } else {
697
+ // transpose
698
+ std::vector<float> xt(n * d);
699
+ for (size_t i = 1; i < n; i++) {
700
+ const float *xi = x + i * d;
701
+ for (size_t j = 0; j < d; j++) {
702
+ xt[j * n + i] = xi[j];
703
+ }
704
+ }
705
+ std::vector<float> trained_d(2);
706
+ #pragma omp parallel for
707
+ for (size_t j = 0; j < d; j++) {
708
+ train_Uniform(rs, rs_arg,
709
+ n, k, xt.data() + j * n,
710
+ trained_d);
711
+ vmin[j] = trained_d[0];
712
+ vmax[j] = trained_d[1];
713
+ }
714
+ }
715
+ }
716
+
717
+
718
+
719
+ /*******************************************************************
720
+ * Similarity: gets vector components and computes a similarity wrt. a
721
+ * query vector stored in the object. The data fields just encapsulate
722
+ * an accumulator.
723
+ */
724
+
725
+ template<int SIMDWIDTH>
726
+ struct SimilarityL2 {};
727
+
728
+
729
+ template<>
730
+ struct SimilarityL2<1> {
731
+ static constexpr int simdwidth = 1;
732
+ static constexpr MetricType metric_type = METRIC_L2;
733
+
734
+ const float *y, *yi;
735
+
736
+ explicit SimilarityL2 (const float * y): y(y) {}
737
+
738
+ /******* scalar accumulator *******/
739
+
740
+ float accu;
741
+
742
+ void begin () {
743
+ accu = 0;
744
+ yi = y;
745
+ }
746
+
747
+ void add_component (float x) {
748
+ float tmp = *yi++ - x;
749
+ accu += tmp * tmp;
750
+ }
751
+
752
+ void add_component_2 (float x1, float x2) {
753
+ float tmp = x1 - x2;
754
+ accu += tmp * tmp;
755
+ }
756
+
757
+ float result () {
758
+ return accu;
759
+ }
760
+ };
761
+
762
+
763
+ #ifdef USE_AVX
764
+ template<>
765
+ struct SimilarityL2<8> {
766
+ static constexpr int simdwidth = 8;
767
+ static constexpr MetricType metric_type = METRIC_L2;
768
+
769
+ const float *y, *yi;
770
+
771
+ explicit SimilarityL2 (const float * y): y(y) {}
772
+ __m256 accu8;
773
+
774
+ void begin_8 () {
775
+ accu8 = _mm256_setzero_ps();
776
+ yi = y;
777
+ }
778
+
779
+ void add_8_components (__m256 x) {
780
+ __m256 yiv = _mm256_loadu_ps (yi);
781
+ yi += 8;
782
+ __m256 tmp = yiv - x;
783
+ accu8 += tmp * tmp;
784
+ }
785
+
786
+ void add_8_components_2 (__m256 x, __m256 y) {
787
+ __m256 tmp = y - x;
788
+ accu8 += tmp * tmp;
789
+ }
790
+
791
+ float result_8 () {
792
+ __m256 sum = _mm256_hadd_ps(accu8, accu8);
793
+ __m256 sum2 = _mm256_hadd_ps(sum, sum);
794
+ // now add the 0th and 4th component
795
+ return
796
+ _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
797
+ _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
798
+ }
799
+
800
+ };
801
+
802
+ #endif
803
+
804
+
805
+ template<int SIMDWIDTH>
806
+ struct SimilarityIP {};
807
+
808
+
809
+ template<>
810
+ struct SimilarityIP<1> {
811
+ static constexpr int simdwidth = 1;
812
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
813
+ const float *y, *yi;
814
+
815
+ float accu;
816
+
817
+ explicit SimilarityIP (const float * y):
818
+ y (y) {}
819
+
820
+ void begin () {
821
+ accu = 0;
822
+ yi = y;
823
+ }
824
+
825
+ void add_component (float x) {
826
+ accu += *yi++ * x;
827
+ }
828
+
829
+ void add_component_2 (float x1, float x2) {
830
+ accu += x1 * x2;
831
+ }
832
+
833
+ float result () {
834
+ return accu;
835
+ }
836
+ };
837
+
838
+ #ifdef USE_AVX
839
+
840
+ template<>
841
+ struct SimilarityIP<8> {
842
+ static constexpr int simdwidth = 8;
843
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
844
+
845
+ const float *y, *yi;
846
+
847
+ float accu;
848
+
849
+ explicit SimilarityIP (const float * y):
850
+ y (y) {}
851
+
852
+ __m256 accu8;
853
+
854
+ void begin_8 () {
855
+ accu8 = _mm256_setzero_ps();
856
+ yi = y;
857
+ }
858
+
859
+ void add_8_components (__m256 x) {
860
+ __m256 yiv = _mm256_loadu_ps (yi);
861
+ yi += 8;
862
+ accu8 += yiv * x;
863
+ }
864
+
865
+ void add_8_components_2 (__m256 x1, __m256 x2) {
866
+ accu8 += x1 * x2;
867
+ }
868
+
869
+ float result_8 () {
870
+ __m256 sum = _mm256_hadd_ps(accu8, accu8);
871
+ __m256 sum2 = _mm256_hadd_ps(sum, sum);
872
+ // now add the 0th and 4th component
873
+ return
874
+ _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
875
+ _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
876
+ }
877
+ };
878
+ #endif
879
+
880
+
881
+ /*******************************************************************
882
+ * DistanceComputer: combines a similarity and a quantizer to do
883
+ * code-to-vector or code-to-code comparisons
884
+ *******************************************************************/
885
+
886
+ template<class Quantizer, class Similarity, int SIMDWIDTH>
887
+ struct DCTemplate : SQDistanceComputer {};
888
+
889
+ template<class Quantizer, class Similarity>
890
+ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
891
+ {
892
+ using Sim = Similarity;
893
+
894
+ Quantizer quant;
895
+
896
+ DCTemplate(size_t d, const std::vector<float> &trained):
897
+ quant(d, trained)
898
+ {}
899
+
900
+ float compute_distance(const float* x, const uint8_t* code) const {
901
+
902
+ Similarity sim(x);
903
+ sim.begin();
904
+ for (size_t i = 0; i < quant.d; i++) {
905
+ float xi = quant.reconstruct_component(code, i);
906
+ sim.add_component(xi);
907
+ }
908
+ return sim.result();
909
+ }
910
+
911
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
912
+ const {
913
+ Similarity sim(nullptr);
914
+ sim.begin();
915
+ for (size_t i = 0; i < quant.d; i++) {
916
+ float x1 = quant.reconstruct_component(code1, i);
917
+ float x2 = quant.reconstruct_component(code2, i);
918
+ sim.add_component_2(x1, x2);
919
+ }
920
+ return sim.result();
921
+ }
922
+
923
+ void set_query (const float *x) final {
924
+ q = x;
925
+ }
926
+
927
+ /// compute distance of vector i to current query
928
+ float operator () (idx_t i) final {
929
+ return compute_distance (q, codes + i * code_size);
930
+ }
931
+
932
+ float symmetric_dis (idx_t i, idx_t j) override {
933
+ return compute_code_distance (codes + i * code_size,
934
+ codes + j * code_size);
935
+ }
936
+
937
+ float query_to_code (const uint8_t * code) const {
938
+ return compute_distance (q, code);
939
+ }
940
+
941
+ };
942
+
943
+ #ifdef USE_F16C
944
+
945
+ template<class Quantizer, class Similarity>
946
+ struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
947
+ {
948
+ using Sim = Similarity;
949
+
950
+ Quantizer quant;
951
+
952
+ DCTemplate(size_t d, const std::vector<float> &trained):
953
+ quant(d, trained)
954
+ {}
955
+
956
+ float compute_distance(const float* x, const uint8_t* code) const {
957
+
958
+ Similarity sim(x);
959
+ sim.begin_8();
960
+ for (size_t i = 0; i < quant.d; i += 8) {
961
+ __m256 xi = quant.reconstruct_8_components(code, i);
962
+ sim.add_8_components(xi);
963
+ }
964
+ return sim.result_8();
965
+ }
966
+
967
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
968
+ const {
969
+ Similarity sim(nullptr);
970
+ sim.begin_8();
971
+ for (size_t i = 0; i < quant.d; i += 8) {
972
+ __m256 x1 = quant.reconstruct_8_components(code1, i);
973
+ __m256 x2 = quant.reconstruct_8_components(code2, i);
974
+ sim.add_8_components_2(x1, x2);
975
+ }
976
+ return sim.result_8();
977
+ }
978
+
979
+ void set_query (const float *x) final {
980
+ q = x;
981
+ }
982
+
983
+ /// compute distance of vector i to current query
984
+ float operator () (idx_t i) final {
985
+ return compute_distance (q, codes + i * code_size);
986
+ }
987
+
988
+ float symmetric_dis (idx_t i, idx_t j) override {
989
+ return compute_code_distance (codes + i * code_size,
990
+ codes + j * code_size);
991
+ }
992
+
993
+ float query_to_code (const uint8_t * code) const {
994
+ return compute_distance (q, code);
995
+ }
996
+
997
+ };
998
+
999
+ #endif
1000
+
1001
+
1002
+
1003
+ /*******************************************************************
1004
+ * DistanceComputerByte: computes distances in the integer domain
1005
+ *******************************************************************/
1006
+
1007
+ template<class Similarity, int SIMDWIDTH>
1008
+ struct DistanceComputerByte : SQDistanceComputer {};
1009
+
1010
+ template<class Similarity>
1011
+ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1012
+ using Sim = Similarity;
1013
+
1014
+ int d;
1015
+ std::vector<uint8_t> tmp;
1016
+
1017
+ DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1018
+ }
1019
+
1020
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1021
+ const {
1022
+ int accu = 0;
1023
+ for (int i = 0; i < d; i++) {
1024
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1025
+ accu += int(code1[i]) * code2[i];
1026
+ } else {
1027
+ int diff = int(code1[i]) - code2[i];
1028
+ accu += diff * diff;
1029
+ }
1030
+ }
1031
+ return accu;
1032
+ }
1033
+
1034
+ void set_query (const float *x) final {
1035
+ for (int i = 0; i < d; i++) {
1036
+ tmp[i] = int(x[i]);
1037
+ }
1038
+ }
1039
+
1040
+ int compute_distance(const float* x, const uint8_t* code) {
1041
+ set_query(x);
1042
+ return compute_code_distance(tmp.data(), code);
1043
+ }
1044
+
1045
+ /// compute distance of vector i to current query
1046
+ float operator () (idx_t i) final {
1047
+ return compute_distance (q, codes + i * code_size);
1048
+ }
1049
+
1050
+ float symmetric_dis (idx_t i, idx_t j) override {
1051
+ return compute_code_distance (codes + i * code_size,
1052
+ codes + j * code_size);
1053
+ }
1054
+
1055
+ float query_to_code (const uint8_t * code) const {
1056
+ return compute_code_distance (tmp.data(), code);
1057
+ }
1058
+
1059
+ };
1060
+
1061
+ #ifdef USE_AVX
1062
+
1063
+
1064
+ template<class Similarity>
1065
+ struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
1066
+ using Sim = Similarity;
1067
+
1068
+ int d;
1069
+ std::vector<uint8_t> tmp;
1070
+
1071
+ DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1072
+ }
1073
+
1074
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1075
+ const {
1076
+ // __m256i accu = _mm256_setzero_ps ();
1077
+ __m256i accu = _mm256_setzero_si256 ();
1078
+ for (int i = 0; i < d; i += 16) {
1079
+ // load 16 bytes, convert to 16 uint16_t
1080
+ __m256i c1 = _mm256_cvtepu8_epi16
1081
+ (_mm_loadu_si128((__m128i*)(code1 + i)));
1082
+ __m256i c2 = _mm256_cvtepu8_epi16
1083
+ (_mm_loadu_si128((__m128i*)(code2 + i)));
1084
+ __m256i prod32;
1085
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1086
+ prod32 = _mm256_madd_epi16(c1, c2);
1087
+ } else {
1088
+ __m256i diff = _mm256_sub_epi16(c1, c2);
1089
+ prod32 = _mm256_madd_epi16(diff, diff);
1090
+ }
1091
+ accu = _mm256_add_epi32 (accu, prod32);
1092
+
1093
+ }
1094
+ __m128i sum = _mm256_extractf128_si256(accu, 0);
1095
+ sum = _mm_add_epi32 (sum, _mm256_extractf128_si256(accu, 1));
1096
+ sum = _mm_hadd_epi32 (sum, sum);
1097
+ sum = _mm_hadd_epi32 (sum, sum);
1098
+ return _mm_cvtsi128_si32 (sum);
1099
+ }
1100
+
1101
+ void set_query (const float *x) final {
1102
+ /*
1103
+ for (int i = 0; i < d; i += 8) {
1104
+ __m256 xi = _mm256_loadu_ps (x + i);
1105
+ __m256i ci = _mm256_cvtps_epi32(xi);
1106
+ */
1107
+ for (int i = 0; i < d; i++) {
1108
+ tmp[i] = int(x[i]);
1109
+ }
1110
+ }
1111
+
1112
+ int compute_distance(const float* x, const uint8_t* code) {
1113
+ set_query(x);
1114
+ return compute_code_distance(tmp.data(), code);
1115
+ }
1116
+
1117
+ /// compute distance of vector i to current query
1118
+ float operator () (idx_t i) final {
1119
+ return compute_distance (q, codes + i * code_size);
1120
+ }
1121
+
1122
+ float symmetric_dis (idx_t i, idx_t j) override {
1123
+ return compute_code_distance (codes + i * code_size,
1124
+ codes + j * code_size);
1125
+ }
1126
+
1127
+ float query_to_code (const uint8_t * code) const {
1128
+ return compute_code_distance (tmp.data(), code);
1129
+ }
1130
+
1131
+
1132
+ };
1133
+
1134
+ #endif
1135
+
1136
+ /*******************************************************************
1137
+ * select_distance_computer: runtime selection of template
1138
+ * specialization
1139
+ *******************************************************************/
1140
+
1141
+
1142
+ template<class Sim>
1143
+ SQDistanceComputer *select_distance_computer (
1144
+ QuantizerType qtype,
1145
+ size_t d, const std::vector<float> & trained)
1146
+ {
1147
+ constexpr int SIMDWIDTH = Sim::simdwidth;
1148
+ switch(qtype) {
1149
+ case ScalarQuantizer::QT_8bit_uniform:
1150
+ return new DCTemplate<QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1151
+ Sim, SIMDWIDTH>(d, trained);
1152
+
1153
+ case ScalarQuantizer::QT_4bit_uniform:
1154
+ return new DCTemplate<QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1155
+ Sim, SIMDWIDTH>(d, trained);
1156
+
1157
+ case ScalarQuantizer::QT_8bit:
1158
+ return new DCTemplate<QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1159
+ Sim, SIMDWIDTH>(d, trained);
1160
+
1161
+ case ScalarQuantizer::QT_6bit:
1162
+ return new DCTemplate<QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1163
+ Sim, SIMDWIDTH>(d, trained);
1164
+
1165
+ case ScalarQuantizer::QT_4bit:
1166
+ return new DCTemplate<QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1167
+ Sim, SIMDWIDTH>(d, trained);
1168
+
1169
+ case ScalarQuantizer::QT_fp16:
1170
+ return new DCTemplate
1171
+ <QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
1172
+
1173
+ case ScalarQuantizer::QT_8bit_direct:
1174
+ if (d % 16 == 0) {
1175
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1176
+ } else {
1177
+ return new DCTemplate
1178
+ <Quantizer8bitDirect<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
1179
+ }
1180
+ }
1181
+ FAISS_THROW_MSG ("unknown qtype");
1182
+ return nullptr;
1183
+ }
1184
+
1185
+
1186
+
1187
+ } // anonymous namespace
1188
+
1189
+
1190
+
1191
+ /*******************************************************************
1192
+ * ScalarQuantizer implementation
1193
+ ********************************************************************/
1194
+
1195
+
1196
+
1197
+ ScalarQuantizer::ScalarQuantizer
1198
+ (size_t d, QuantizerType qtype):
1199
+ qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d (d)
1200
+ {
1201
+ switch (qtype) {
1202
+ case QT_8bit:
1203
+ case QT_8bit_uniform:
1204
+ case QT_8bit_direct:
1205
+ code_size = d;
1206
+ break;
1207
+ case QT_4bit:
1208
+ case QT_4bit_uniform:
1209
+ code_size = (d + 1) / 2;
1210
+ break;
1211
+ case QT_6bit:
1212
+ code_size = (d * 6 + 7) / 8;
1213
+ break;
1214
+ case QT_fp16:
1215
+ code_size = d * 2;
1216
+ break;
1217
+ }
1218
+
1219
+ }
1220
+
1221
+ ScalarQuantizer::ScalarQuantizer ():
1222
+ qtype(QT_8bit),
1223
+ rangestat(RS_minmax), rangestat_arg(0), d (0), code_size(0)
1224
+ {}
1225
+
1226
+ void ScalarQuantizer::train (size_t n, const float *x)
1227
+ {
1228
+ int bit_per_dim =
1229
+ qtype == QT_4bit_uniform ? 4 :
1230
+ qtype == QT_4bit ? 4 :
1231
+ qtype == QT_6bit ? 6 :
1232
+ qtype == QT_8bit_uniform ? 8 :
1233
+ qtype == QT_8bit ? 8 : -1;
1234
+
1235
+ switch (qtype) {
1236
+ case QT_4bit_uniform: case QT_8bit_uniform:
1237
+ train_Uniform (rangestat, rangestat_arg,
1238
+ n * d, 1 << bit_per_dim, x, trained);
1239
+ break;
1240
+ case QT_4bit: case QT_8bit: case QT_6bit:
1241
+ train_NonUniform (rangestat, rangestat_arg,
1242
+ n, d, 1 << bit_per_dim, x, trained);
1243
+ break;
1244
+ case QT_fp16:
1245
+ case QT_8bit_direct:
1246
+ // no training necessary
1247
+ break;
1248
+ }
1249
+ }
1250
+
1251
+ void ScalarQuantizer::train_residual(size_t n,
1252
+ const float *x,
1253
+ Index *quantizer,
1254
+ bool by_residual,
1255
+ bool verbose)
1256
+ {
1257
+ const float * x_in = x;
1258
+
1259
+ // 100k points more than enough
1260
+ x = fvecs_maybe_subsample (
1261
+ d, (size_t*)&n, 100000,
1262
+ x, verbose, 1234);
1263
+
1264
+ ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
1265
+
1266
+ if (by_residual) {
1267
+ std::vector<Index::idx_t> idx(n);
1268
+ quantizer->assign (n, x, idx.data());
1269
+
1270
+ std::vector<float> residuals(n * d);
1271
+ quantizer->compute_residual_n (n, x, residuals.data(), idx.data());
1272
+
1273
+ train (n, residuals.data());
1274
+ } else {
1275
+ train (n, x);
1276
+ }
1277
+ }
1278
+
1279
+
1280
+ ScalarQuantizer::Quantizer *ScalarQuantizer::select_quantizer () const
1281
+ {
1282
+ #ifdef USE_F16C
1283
+ if (d % 8 == 0) {
1284
+ return select_quantizer_1<8> (qtype, d, trained);
1285
+ } else
1286
+ #endif
1287
+ {
1288
+ return select_quantizer_1<1> (qtype, d, trained);
1289
+ }
1290
+ }
1291
+
1292
+
1293
+ void ScalarQuantizer::compute_codes (const float * x,
1294
+ uint8_t * codes,
1295
+ size_t n) const
1296
+ {
1297
+ std::unique_ptr<Quantizer> squant(select_quantizer ());
1298
+
1299
+ memset (codes, 0, code_size * n);
1300
+ #pragma omp parallel for
1301
+ for (size_t i = 0; i < n; i++)
1302
+ squant->encode_vector (x + i * d, codes + i * code_size);
1303
+ }
1304
+
1305
+ void ScalarQuantizer::decode (const uint8_t *codes, float *x, size_t n) const
1306
+ {
1307
+ std::unique_ptr<Quantizer> squant(select_quantizer ());
1308
+
1309
+ #pragma omp parallel for
1310
+ for (size_t i = 0; i < n; i++)
1311
+ squant->decode_vector (codes + i * code_size, x + i * d);
1312
+ }
1313
+
+ SQDistanceComputer *
+ ScalarQuantizer::get_distance_computer (MetricType metric) const
+ {
+     FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
+ #ifdef USE_F16C
+     if (d % 8 == 0) {
+         if (metric == METRIC_L2) {
+             return select_distance_computer<SimilarityL2<8> >
+                 (qtype, d, trained);
+         } else {
+             return select_distance_computer<SimilarityIP<8> >
+                 (qtype, d, trained);
+         }
+     } else
+ #endif
+     {
+         if (metric == METRIC_L2) {
+             return select_distance_computer<SimilarityL2<1> >
+                 (qtype, d, trained);
+         } else {
+             return select_distance_computer<SimilarityIP<1> >
+                 (qtype, d, trained);
+         }
+     }
+ }
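+ // The distance computer performs asymmetric comparisons: the query stays in
+ // float while database vectors are decoded on the fly from their codes, so
+ // the database never needs to be fully reconstructed at search time.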
+
+ /*******************************************************************
+  * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
+  *
+  * It is an InvertedListScanner, but is designed to work with
+  * IndexScalarQuantizer as well.
+  ********************************************************************/
+
+ namespace {
+
+ template<class DCClass>
+ struct IVFSQScannerIP: InvertedListScanner {
+     DCClass dc;
+     bool store_pairs, by_residual;
+
+     size_t code_size;
+
+     idx_t list_no;  /// current list (set to 0 for Flat index)
+     float accu0;    /// added to all distances
+
+     IVFSQScannerIP(int d, const std::vector<float> & trained,
+                    size_t code_size, bool store_pairs,
+                    bool by_residual):
+         dc(d, trained), store_pairs(store_pairs),
+         by_residual(by_residual),
+         code_size(code_size), list_no(0), accu0(0)
+     {}
+
+     void set_query (const float *query) override {
+         dc.set_query (query);
+     }
+
+     void set_list (idx_t list_no, float coarse_dis) override {
+         this->list_no = list_no;
+         accu0 = by_residual ? coarse_dis : 0;
+     }
+
+     float distance_to_code (const uint8_t *code) const final {
+         return accu0 + dc.query_to_code (code);
+     }
+
+     size_t scan_codes (size_t list_size,
+                        const uint8_t *codes,
+                        const idx_t *ids,
+                        float *simi, idx_t *idxi,
+                        size_t k) const override
+     {
+         size_t nup = 0;
+
+         for (size_t j = 0; j < list_size; j++) {
+
+             float accu = accu0 + dc.query_to_code (codes);
+
+             if (accu > simi [0]) {
+                 minheap_pop (k, simi, idxi);
+                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
+                 minheap_push (k, simi, idxi, accu, id);
+                 nup++;
+             }
+             codes += code_size;
+         }
+         return nup;
+     }
+
+     void scan_codes_range (size_t list_size,
+                            const uint8_t *codes,
+                            const idx_t *ids,
+                            float radius,
+                            RangeQueryResult & res) const override
+     {
+         for (size_t j = 0; j < list_size; j++) {
+             float accu = accu0 + dc.query_to_code (codes);
+             if (accu > radius) {
+                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
+                 res.add (accu, id);
+             }
+             codes += code_size;
+         }
+     }
+ };
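+ // Inner-product scanner: higher scores are better, so results are kept in a
+ // min-heap where simi[0] is the worst of the current top-k; a candidate is
+ // accepted when it beats that entry. With by_residual, the coarse distance
+ // to the list centroid is added to every code distance via accu0.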
+
+ template<class DCClass>
+ struct IVFSQScannerL2: InvertedListScanner {
+
+     DCClass dc;
+
+     bool store_pairs, by_residual;
+     size_t code_size;
+     const Index *quantizer;
+     idx_t list_no;    /// current inverted list
+     const float *x;   /// current query
+
+     std::vector<float> tmp;
+
+     IVFSQScannerL2(int d, const std::vector<float> & trained,
+                    size_t code_size, const Index *quantizer,
+                    bool store_pairs, bool by_residual):
+         dc(d, trained), store_pairs(store_pairs), by_residual(by_residual),
+         code_size(code_size), quantizer(quantizer),
+         list_no (0), x (nullptr), tmp (d)
+     {
+     }
+
+     void set_query (const float *query) override {
+         x = query;
+         if (!quantizer) {
+             dc.set_query (query);
+         }
+     }
+
+     void set_list (idx_t list_no, float /*coarse_dis*/) override {
+         if (by_residual) {
+             this->list_no = list_no;
+             // residual of the query w.r.t. the list's centroid
+             quantizer->compute_residual (x, tmp.data(), list_no);
+             dc.set_query (tmp.data ());
+         } else {
+             dc.set_query (x);
+         }
+     }
+
+     float distance_to_code (const uint8_t *code) const final {
+         return dc.query_to_code (code);
+     }
+
+     size_t scan_codes (size_t list_size,
+                        const uint8_t *codes,
+                        const idx_t *ids,
+                        float *simi, idx_t *idxi,
+                        size_t k) const override
+     {
+         size_t nup = 0;
+         for (size_t j = 0; j < list_size; j++) {
+
+             float dis = dc.query_to_code (codes);
+
+             if (dis < simi [0]) {
+                 maxheap_pop (k, simi, idxi);
+                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
+                 maxheap_push (k, simi, idxi, dis, id);
+                 nup++;
+             }
+             codes += code_size;
+         }
+         return nup;
+     }
+
+     void scan_codes_range (size_t list_size,
+                            const uint8_t *codes,
+                            const idx_t *ids,
+                            float radius,
+                            RangeQueryResult & res) const override
+     {
+         for (size_t j = 0; j < list_size; j++) {
+             float dis = dc.query_to_code (codes);
+             if (dis < radius) {
+                 int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
+                 res.add (dis, id);
+             }
+             codes += code_size;
+         }
+     }
+ };
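+ // L2 scanner: lower distances are better, so results sit in a max-heap where
+ // simi[0] is the current k-th best (largest) distance. With by_residual, the
+ // query itself is re-expressed as a residual against each visited list's
+ // centroid in set_list, so stored codes can be compared to it directly.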
+
+ template<class DCClass>
+ InvertedListScanner* sel2_InvertedListScanner
+       (const ScalarQuantizer *sq,
+        const Index *quantizer, bool store_pairs, bool r)
+ {
+     if (DCClass::Sim::metric_type == METRIC_L2) {
+         return new IVFSQScannerL2<DCClass>(sq->d, sq->trained, sq->code_size,
+                                            quantizer, store_pairs, r);
+     } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
+         return new IVFSQScannerIP<DCClass>(sq->d, sq->trained, sq->code_size,
+                                            store_pairs, r);
+     } else {
+         FAISS_THROW_MSG("unsupported metric type");
+     }
+ }
+
+ template<class Similarity, class Codec, bool uniform>
+ InvertedListScanner* sel12_InvertedListScanner
+       (const ScalarQuantizer *sq,
+        const Index *quantizer, bool store_pairs, bool r)
+ {
+     constexpr int SIMDWIDTH = Similarity::simdwidth;
+     using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
+     using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
+     return sel2_InvertedListScanner<DCClass> (sq, quantizer, store_pairs, r);
+ }
+
+ template<class Similarity>
+ InvertedListScanner* sel1_InvertedListScanner
+       (const ScalarQuantizer *sq, const Index *quantizer,
+        bool store_pairs, bool r)
+ {
+     constexpr int SIMDWIDTH = Similarity::simdwidth;
+     switch(sq->qtype) {
+     case ScalarQuantizer::QT_8bit_uniform:
+         return sel12_InvertedListScanner
+             <Similarity, Codec8bit, true>(sq, quantizer, store_pairs, r);
+     case ScalarQuantizer::QT_4bit_uniform:
+         return sel12_InvertedListScanner
+             <Similarity, Codec4bit, true>(sq, quantizer, store_pairs, r);
+     case ScalarQuantizer::QT_8bit:
+         return sel12_InvertedListScanner
+             <Similarity, Codec8bit, false>(sq, quantizer, store_pairs, r);
+     case ScalarQuantizer::QT_4bit:
+         return sel12_InvertedListScanner
+             <Similarity, Codec4bit, false>(sq, quantizer, store_pairs, r);
+     case ScalarQuantizer::QT_6bit:
+         return sel12_InvertedListScanner
+             <Similarity, Codec6bit, false>(sq, quantizer, store_pairs, r);
+     case ScalarQuantizer::QT_fp16:
+         return sel2_InvertedListScanner
+             <DCTemplate<QuantizerFP16<SIMDWIDTH>, Similarity, SIMDWIDTH> >
+             (sq, quantizer, store_pairs, r);
+     case ScalarQuantizer::QT_8bit_direct:
+         if (sq->d % 16 == 0) {
+             return sel2_InvertedListScanner
+                 <DistanceComputerByte<Similarity, SIMDWIDTH> >
+                 (sq, quantizer, store_pairs, r);
+         } else {
+             return sel2_InvertedListScanner
+                 <DCTemplate<Quantizer8bitDirect<SIMDWIDTH>,
+                             Similarity, SIMDWIDTH> >
+                 (sq, quantizer, store_pairs, r);
+         }
+     }
+
+     FAISS_THROW_MSG ("unknown qtype");
+     return nullptr;
+ }
+
+ template<int SIMDWIDTH>
+ InvertedListScanner* sel0_InvertedListScanner
+       (MetricType mt, const ScalarQuantizer *sq,
+        const Index *quantizer, bool store_pairs, bool by_residual)
+ {
+     if (mt == METRIC_L2) {
+         return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH> >
+             (sq, quantizer, store_pairs, by_residual);
+     } else if (mt == METRIC_INNER_PRODUCT) {
+         return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH> >
+             (sq, quantizer, store_pairs, by_residual);
+     } else {
+         FAISS_THROW_MSG("unsupported metric type");
+     }
+ }
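+ // Dispatch chain: sel0 fixes the SIMD width and the metric (SimilarityL2 /
+ // SimilarityIP), sel1 switches on the quantizer type, sel12 binds the codec
+ // (Codec4bit/6bit/8bit) and the uniform flag into a DCTemplate distance
+ // computer, and sel2 finally instantiates the matching L2 or IP scanner.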
+
+ } // anonymous namespace
+
+ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner
+         (MetricType mt, const Index *quantizer,
+          bool store_pairs, bool by_residual) const
+ {
+ #ifdef USE_F16C
+     if (d % 8 == 0) {
+         return sel0_InvertedListScanner<8>
+             (mt, this, quantizer, store_pairs, by_residual);
+     } else
+ #endif
+     {
+         return sel0_InvertedListScanner<1>
+             (mt, this, quantizer, store_pairs, by_residual);
+     }
+ }
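+ // Usage sketch (hypothetical caller, placeholder variables): an IVF search
+ // loop would obtain one scanner per query and drive it with the methods
+ // defined above, e.g.
+ //
+ //   std::unique_ptr<InvertedListScanner> scanner (
+ //       sq.select_InvertedListScanner (METRIC_L2, coarse_quantizer,
+ //                                      /*store_pairs=*/false,
+ //                                      /*by_residual=*/true));
+ //   scanner->set_query (q);
+ //   scanner->set_list (list_no, coarse_dis);
+ //   scanner->scan_codes (list_size, codes, ids, simi, idxi, k);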
+
+ } // namespace faiss