faiss 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (226)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +103 -3
  4. data/ext/faiss/ext.cpp +99 -32
  5. data/ext/faiss/extconf.rb +12 -2
  6. data/lib/faiss/ext.bundle +0 -0
  7. data/lib/faiss/index.rb +3 -3
  8. data/lib/faiss/index_binary.rb +3 -3
  9. data/lib/faiss/kmeans.rb +1 -1
  10. data/lib/faiss/pca_matrix.rb +2 -2
  11. data/lib/faiss/product_quantizer.rb +3 -3
  12. data/lib/faiss/version.rb +1 -1
  13. data/vendor/faiss/AutoTune.cpp +719 -0
  14. data/vendor/faiss/AutoTune.h +212 -0
  15. data/vendor/faiss/Clustering.cpp +261 -0
  16. data/vendor/faiss/Clustering.h +101 -0
  17. data/vendor/faiss/IVFlib.cpp +339 -0
  18. data/vendor/faiss/IVFlib.h +132 -0
  19. data/vendor/faiss/Index.cpp +171 -0
  20. data/vendor/faiss/Index.h +261 -0
  21. data/vendor/faiss/Index2Layer.cpp +437 -0
  22. data/vendor/faiss/Index2Layer.h +85 -0
  23. data/vendor/faiss/IndexBinary.cpp +77 -0
  24. data/vendor/faiss/IndexBinary.h +163 -0
  25. data/vendor/faiss/IndexBinaryFlat.cpp +83 -0
  26. data/vendor/faiss/IndexBinaryFlat.h +54 -0
  27. data/vendor/faiss/IndexBinaryFromFloat.cpp +78 -0
  28. data/vendor/faiss/IndexBinaryFromFloat.h +52 -0
  29. data/vendor/faiss/IndexBinaryHNSW.cpp +325 -0
  30. data/vendor/faiss/IndexBinaryHNSW.h +56 -0
  31. data/vendor/faiss/IndexBinaryIVF.cpp +671 -0
  32. data/vendor/faiss/IndexBinaryIVF.h +211 -0
  33. data/vendor/faiss/IndexFlat.cpp +508 -0
  34. data/vendor/faiss/IndexFlat.h +175 -0
  35. data/vendor/faiss/IndexHNSW.cpp +1090 -0
  36. data/vendor/faiss/IndexHNSW.h +170 -0
  37. data/vendor/faiss/IndexIVF.cpp +909 -0
  38. data/vendor/faiss/IndexIVF.h +353 -0
  39. data/vendor/faiss/IndexIVFFlat.cpp +502 -0
  40. data/vendor/faiss/IndexIVFFlat.h +118 -0
  41. data/vendor/faiss/IndexIVFPQ.cpp +1207 -0
  42. data/vendor/faiss/IndexIVFPQ.h +161 -0
  43. data/vendor/faiss/IndexIVFPQR.cpp +219 -0
  44. data/vendor/faiss/IndexIVFPQR.h +65 -0
  45. data/vendor/faiss/IndexIVFSpectralHash.cpp +331 -0
  46. data/vendor/faiss/IndexIVFSpectralHash.h +75 -0
  47. data/vendor/faiss/IndexLSH.cpp +225 -0
  48. data/vendor/faiss/IndexLSH.h +87 -0
  49. data/vendor/faiss/IndexLattice.cpp +143 -0
  50. data/vendor/faiss/IndexLattice.h +68 -0
  51. data/vendor/faiss/IndexPQ.cpp +1188 -0
  52. data/vendor/faiss/IndexPQ.h +199 -0
  53. data/vendor/faiss/IndexPreTransform.cpp +288 -0
  54. data/vendor/faiss/IndexPreTransform.h +91 -0
  55. data/vendor/faiss/IndexReplicas.cpp +123 -0
  56. data/vendor/faiss/IndexReplicas.h +76 -0
  57. data/vendor/faiss/IndexScalarQuantizer.cpp +317 -0
  58. data/vendor/faiss/IndexScalarQuantizer.h +127 -0
  59. data/vendor/faiss/IndexShards.cpp +317 -0
  60. data/vendor/faiss/IndexShards.h +100 -0
  61. data/vendor/faiss/InvertedLists.cpp +623 -0
  62. data/vendor/faiss/InvertedLists.h +334 -0
  63. data/vendor/faiss/LICENSE +21 -0
  64. data/vendor/faiss/MatrixStats.cpp +252 -0
  65. data/vendor/faiss/MatrixStats.h +62 -0
  66. data/vendor/faiss/MetaIndexes.cpp +351 -0
  67. data/vendor/faiss/MetaIndexes.h +126 -0
  68. data/vendor/faiss/OnDiskInvertedLists.cpp +674 -0
  69. data/vendor/faiss/OnDiskInvertedLists.h +127 -0
  70. data/vendor/faiss/VectorTransform.cpp +1157 -0
  71. data/vendor/faiss/VectorTransform.h +322 -0
  72. data/vendor/faiss/c_api/AutoTune_c.cpp +83 -0
  73. data/vendor/faiss/c_api/AutoTune_c.h +64 -0
  74. data/vendor/faiss/c_api/Clustering_c.cpp +139 -0
  75. data/vendor/faiss/c_api/Clustering_c.h +117 -0
  76. data/vendor/faiss/c_api/IndexFlat_c.cpp +140 -0
  77. data/vendor/faiss/c_api/IndexFlat_c.h +115 -0
  78. data/vendor/faiss/c_api/IndexIVFFlat_c.cpp +64 -0
  79. data/vendor/faiss/c_api/IndexIVFFlat_c.h +58 -0
  80. data/vendor/faiss/c_api/IndexIVF_c.cpp +92 -0
  81. data/vendor/faiss/c_api/IndexIVF_c.h +135 -0
  82. data/vendor/faiss/c_api/IndexLSH_c.cpp +37 -0
  83. data/vendor/faiss/c_api/IndexLSH_c.h +40 -0
  84. data/vendor/faiss/c_api/IndexShards_c.cpp +44 -0
  85. data/vendor/faiss/c_api/IndexShards_c.h +42 -0
  86. data/vendor/faiss/c_api/Index_c.cpp +105 -0
  87. data/vendor/faiss/c_api/Index_c.h +183 -0
  88. data/vendor/faiss/c_api/MetaIndexes_c.cpp +49 -0
  89. data/vendor/faiss/c_api/MetaIndexes_c.h +49 -0
  90. data/vendor/faiss/c_api/clone_index_c.cpp +23 -0
  91. data/vendor/faiss/c_api/clone_index_c.h +32 -0
  92. data/vendor/faiss/c_api/error_c.h +42 -0
  93. data/vendor/faiss/c_api/error_impl.cpp +27 -0
  94. data/vendor/faiss/c_api/error_impl.h +16 -0
  95. data/vendor/faiss/c_api/faiss_c.h +58 -0
  96. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.cpp +96 -0
  97. data/vendor/faiss/c_api/gpu/GpuAutoTune_c.h +56 -0
  98. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.cpp +52 -0
  99. data/vendor/faiss/c_api/gpu/GpuClonerOptions_c.h +68 -0
  100. data/vendor/faiss/c_api/gpu/GpuIndex_c.cpp +17 -0
  101. data/vendor/faiss/c_api/gpu/GpuIndex_c.h +30 -0
  102. data/vendor/faiss/c_api/gpu/GpuIndicesOptions_c.h +38 -0
  103. data/vendor/faiss/c_api/gpu/GpuResources_c.cpp +86 -0
  104. data/vendor/faiss/c_api/gpu/GpuResources_c.h +66 -0
  105. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.cpp +54 -0
  106. data/vendor/faiss/c_api/gpu/StandardGpuResources_c.h +53 -0
  107. data/vendor/faiss/c_api/gpu/macros_impl.h +42 -0
  108. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.cpp +220 -0
  109. data/vendor/faiss/c_api/impl/AuxIndexStructures_c.h +149 -0
  110. data/vendor/faiss/c_api/index_factory_c.cpp +26 -0
  111. data/vendor/faiss/c_api/index_factory_c.h +30 -0
  112. data/vendor/faiss/c_api/index_io_c.cpp +42 -0
  113. data/vendor/faiss/c_api/index_io_c.h +50 -0
  114. data/vendor/faiss/c_api/macros_impl.h +110 -0
  115. data/vendor/faiss/clone_index.cpp +147 -0
  116. data/vendor/faiss/clone_index.h +38 -0
  117. data/vendor/faiss/demos/demo_imi_flat.cpp +151 -0
  118. data/vendor/faiss/demos/demo_imi_pq.cpp +199 -0
  119. data/vendor/faiss/demos/demo_ivfpq_indexing.cpp +146 -0
  120. data/vendor/faiss/demos/demo_sift1M.cpp +252 -0
  121. data/vendor/faiss/gpu/GpuAutoTune.cpp +95 -0
  122. data/vendor/faiss/gpu/GpuAutoTune.h +27 -0
  123. data/vendor/faiss/gpu/GpuCloner.cpp +403 -0
  124. data/vendor/faiss/gpu/GpuCloner.h +82 -0
  125. data/vendor/faiss/gpu/GpuClonerOptions.cpp +28 -0
  126. data/vendor/faiss/gpu/GpuClonerOptions.h +53 -0
  127. data/vendor/faiss/gpu/GpuDistance.h +52 -0
  128. data/vendor/faiss/gpu/GpuFaissAssert.h +29 -0
  129. data/vendor/faiss/gpu/GpuIndex.h +148 -0
  130. data/vendor/faiss/gpu/GpuIndexBinaryFlat.h +89 -0
  131. data/vendor/faiss/gpu/GpuIndexFlat.h +190 -0
  132. data/vendor/faiss/gpu/GpuIndexIVF.h +89 -0
  133. data/vendor/faiss/gpu/GpuIndexIVFFlat.h +85 -0
  134. data/vendor/faiss/gpu/GpuIndexIVFPQ.h +143 -0
  135. data/vendor/faiss/gpu/GpuIndexIVFScalarQuantizer.h +100 -0
  136. data/vendor/faiss/gpu/GpuIndicesOptions.h +30 -0
  137. data/vendor/faiss/gpu/GpuResources.cpp +52 -0
  138. data/vendor/faiss/gpu/GpuResources.h +73 -0
  139. data/vendor/faiss/gpu/StandardGpuResources.cpp +295 -0
  140. data/vendor/faiss/gpu/StandardGpuResources.h +114 -0
  141. data/vendor/faiss/gpu/impl/RemapIndices.cpp +43 -0
  142. data/vendor/faiss/gpu/impl/RemapIndices.h +24 -0
  143. data/vendor/faiss/gpu/perf/IndexWrapper-inl.h +71 -0
  144. data/vendor/faiss/gpu/perf/IndexWrapper.h +39 -0
  145. data/vendor/faiss/gpu/perf/PerfClustering.cpp +115 -0
  146. data/vendor/faiss/gpu/perf/PerfIVFPQAdd.cpp +139 -0
  147. data/vendor/faiss/gpu/perf/WriteIndex.cpp +102 -0
  148. data/vendor/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +130 -0
  149. data/vendor/faiss/gpu/test/TestGpuIndexFlat.cpp +371 -0
  150. data/vendor/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +550 -0
  151. data/vendor/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +450 -0
  152. data/vendor/faiss/gpu/test/TestGpuMemoryException.cpp +84 -0
  153. data/vendor/faiss/gpu/test/TestUtils.cpp +315 -0
  154. data/vendor/faiss/gpu/test/TestUtils.h +93 -0
  155. data/vendor/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +159 -0
  156. data/vendor/faiss/gpu/utils/DeviceMemory.cpp +77 -0
  157. data/vendor/faiss/gpu/utils/DeviceMemory.h +71 -0
  158. data/vendor/faiss/gpu/utils/DeviceUtils.h +185 -0
  159. data/vendor/faiss/gpu/utils/MemorySpace.cpp +89 -0
  160. data/vendor/faiss/gpu/utils/MemorySpace.h +44 -0
  161. data/vendor/faiss/gpu/utils/StackDeviceMemory.cpp +239 -0
  162. data/vendor/faiss/gpu/utils/StackDeviceMemory.h +129 -0
  163. data/vendor/faiss/gpu/utils/StaticUtils.h +83 -0
  164. data/vendor/faiss/gpu/utils/Timer.cpp +60 -0
  165. data/vendor/faiss/gpu/utils/Timer.h +52 -0
  166. data/vendor/faiss/impl/AuxIndexStructures.cpp +305 -0
  167. data/vendor/faiss/impl/AuxIndexStructures.h +246 -0
  168. data/vendor/faiss/impl/FaissAssert.h +95 -0
  169. data/vendor/faiss/impl/FaissException.cpp +66 -0
  170. data/vendor/faiss/impl/FaissException.h +71 -0
  171. data/vendor/faiss/impl/HNSW.cpp +818 -0
  172. data/vendor/faiss/impl/HNSW.h +275 -0
  173. data/vendor/faiss/impl/PolysemousTraining.cpp +953 -0
  174. data/vendor/faiss/impl/PolysemousTraining.h +158 -0
  175. data/vendor/faiss/impl/ProductQuantizer.cpp +876 -0
  176. data/vendor/faiss/impl/ProductQuantizer.h +242 -0
  177. data/vendor/faiss/impl/ScalarQuantizer.cpp +1628 -0
  178. data/vendor/faiss/impl/ScalarQuantizer.h +120 -0
  179. data/vendor/faiss/impl/ThreadedIndex-inl.h +192 -0
  180. data/vendor/faiss/impl/ThreadedIndex.h +80 -0
  181. data/vendor/faiss/impl/index_read.cpp +793 -0
  182. data/vendor/faiss/impl/index_write.cpp +558 -0
  183. data/vendor/faiss/impl/io.cpp +142 -0
  184. data/vendor/faiss/impl/io.h +98 -0
  185. data/vendor/faiss/impl/lattice_Zn.cpp +712 -0
  186. data/vendor/faiss/impl/lattice_Zn.h +199 -0
  187. data/vendor/faiss/index_factory.cpp +392 -0
  188. data/vendor/faiss/index_factory.h +25 -0
  189. data/vendor/faiss/index_io.h +75 -0
  190. data/vendor/faiss/misc/test_blas.cpp +84 -0
  191. data/vendor/faiss/tests/test_binary_flat.cpp +64 -0
  192. data/vendor/faiss/tests/test_dealloc_invlists.cpp +183 -0
  193. data/vendor/faiss/tests/test_ivfpq_codec.cpp +67 -0
  194. data/vendor/faiss/tests/test_ivfpq_indexing.cpp +98 -0
  195. data/vendor/faiss/tests/test_lowlevel_ivf.cpp +566 -0
  196. data/vendor/faiss/tests/test_merge.cpp +258 -0
  197. data/vendor/faiss/tests/test_omp_threads.cpp +14 -0
  198. data/vendor/faiss/tests/test_ondisk_ivf.cpp +220 -0
  199. data/vendor/faiss/tests/test_pairs_decoding.cpp +189 -0
  200. data/vendor/faiss/tests/test_params_override.cpp +231 -0
  201. data/vendor/faiss/tests/test_pq_encoding.cpp +98 -0
  202. data/vendor/faiss/tests/test_sliding_ivf.cpp +240 -0
  203. data/vendor/faiss/tests/test_threaded_index.cpp +253 -0
  204. data/vendor/faiss/tests/test_transfer_invlists.cpp +159 -0
  205. data/vendor/faiss/tutorial/cpp/1-Flat.cpp +98 -0
  206. data/vendor/faiss/tutorial/cpp/2-IVFFlat.cpp +81 -0
  207. data/vendor/faiss/tutorial/cpp/3-IVFPQ.cpp +93 -0
  208. data/vendor/faiss/tutorial/cpp/4-GPU.cpp +119 -0
  209. data/vendor/faiss/tutorial/cpp/5-Multiple-GPUs.cpp +99 -0
  210. data/vendor/faiss/utils/Heap.cpp +122 -0
  211. data/vendor/faiss/utils/Heap.h +495 -0
  212. data/vendor/faiss/utils/WorkerThread.cpp +126 -0
  213. data/vendor/faiss/utils/WorkerThread.h +61 -0
  214. data/vendor/faiss/utils/distances.cpp +765 -0
  215. data/vendor/faiss/utils/distances.h +243 -0
  216. data/vendor/faiss/utils/distances_simd.cpp +809 -0
  217. data/vendor/faiss/utils/extra_distances.cpp +336 -0
  218. data/vendor/faiss/utils/extra_distances.h +54 -0
  219. data/vendor/faiss/utils/hamming-inl.h +472 -0
  220. data/vendor/faiss/utils/hamming.cpp +792 -0
  221. data/vendor/faiss/utils/hamming.h +220 -0
  222. data/vendor/faiss/utils/random.cpp +192 -0
  223. data/vendor/faiss/utils/random.h +60 -0
  224. data/vendor/faiss/utils/utils.cpp +783 -0
  225. data/vendor/faiss/utils/utils.h +181 -0
  226. metadata +216 -2
@@ -0,0 +1,242 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
+ #ifndef FAISS_PRODUCT_QUANTIZER_H
11
+ #define FAISS_PRODUCT_QUANTIZER_H
12
+
13
+ #include <stdint.h>
14
+
15
+ #include <vector>
16
+
17
+ #include <faiss/Clustering.h>
18
+ #include <faiss/utils/Heap.h>
19
+
20
+ namespace faiss {
21
+
22
/** Product Quantizer.
 *
 * A d-dimensional vector x = [x_0 x_1 .. x_(M-1)] is split into M
 * sub-vectors of dsub dimensions; each sub-vector is quantized with its own
 * codebook of ksub centroids and the resulting indices are stored on nbits
 * bits each.
 *
 * NOTE(review): the original comment says "Implemented only for METRIC_L2",
 * but inner-product helpers (compute_inner_prod_table(s), search_ip) are
 * declared below -- confirm which metrics are actually supported.
 */
struct ProductQuantizer {

    using idx_t = Index::idx_t;

    size_t d;              ///< size of the input vectors
    size_t M;              ///< number of subquantizers
    size_t nbits;          ///< number of bits per quantization index

    // values derived from the above (recomputed by set_derived_values())
    size_t dsub;           ///< dimensionality of each subvector
    size_t code_size;      ///< bytes per indexed vector
    size_t ksub;           ///< number of centroids for each subquantizer
    bool verbose;          ///< verbose during training?

    /// centroid initialization strategy used by train()
    enum train_type_t {
        Train_default,
        Train_hot_start,     ///< the centroids are already initialized
        Train_shared,        ///< share dictionary across PQ segments
        Train_hypercube,     ///< initialize centroids with nbits-D hypercube
        Train_hypercube_pca, ///< hypercube variant; presumably PCA-based (TODO confirm in ProductQuantizer.cpp)
    };
    train_type_t train_type;

    ClusteringParameters cp; ///< parameters used during clustering

    /// if non-NULL, use this index for assignment (should be of size
    /// d / M)
    Index *assign_index;

    /// Centroid table, size M * ksub * dsub
    std::vector<float> centroids;

    /// return a pointer to centroid i of subquantizer m, i.e. the dsub
    /// contiguous floats starting at centroids[(m * ksub + i) * dsub]
    float * get_centroids (size_t m, size_t i) {
        return &centroids [(m * ksub + i) * dsub];
    }
    /// const overload of get_centroids
    const float * get_centroids (size_t m, size_t i) const {
        return &centroids [(m * ksub + i) * dsub];
    }

    /// Train the product quantizer on n points x (presumably n * d floats).
    /// Non-default clustering parameters can be set beforehand through `cp`.
    void train (int n, const float *x);

    ProductQuantizer(size_t d,        /* dimensionality of the input vectors */
                     size_t M,        /* number of subquantizers */
                     size_t nbits);   /* number of bits per subvector index */

    ProductQuantizer ();

    /// compute derived values when d, M and nbits have been set
    void set_derived_values ();

    /// Define the centroids for subquantizer m
    void set_params (const float * centroids, int m);

    /// Quantize one vector with the product quantizer
    void compute_code (const float * x, uint8_t * code) const ;

    /// same as compute_code for several vectors
    void compute_codes (const float * x,
                        uint8_t * codes,
                        size_t n) const ;

    /// speed up code assignment using assign_index
    /// (non-const because the index is changed)
    void compute_codes_with_assign_index (
                        const float * x,
                        uint8_t * codes,
                        size_t n);

    /// decode a vector from a given code (or n vectors if third argument)
    void decode (const uint8_t *code, float *x) const;
    void decode (const uint8_t *code, float *x, size_t n) const;

    /// If we happen to have the distance tables precomputed, this is
    /// more efficient to compute the codes.
    void compute_code_from_distance_table (const float *tab,
                                           uint8_t *code) const;


    /** Compute distance table for one vector.
     *
     * The distance table for x = [x_0 x_1 .. x_(M-1)] is a M * ksub
     * matrix that contains
     *
     *   dis_table (m, j) = || x_m - c_(m, j)||^2
     *   for m = 0..M-1 and j = 0 .. ksub - 1
     *
     * where c_(m, j) is the centroid no j of sub-quantizer m.
     *
     * @param x         input vector size d
     * @param dis_table output table, size M * ksub
     */
    void compute_distance_table (const float * x,
                                 float * dis_table) const;

    /// Same layout as compute_distance_table; presumably fills
    /// dis_table(m, j) = <x_m, c_(m, j)> instead of the L2 distance
    /// (TODO confirm in ProductQuantizer.cpp).
    void compute_inner_prod_table (const float * x,
                                   float * dis_table) const;


    /** compute distance table for several vectors
     * @param nx        nb of input vectors
     * @param x         input vector size nx * d
     * @param dis_table output table, size nx * M * ksub
     */
    void compute_distance_tables (size_t nx,
                                  const float * x,
                                  float * dis_tables) const;

    /// multi-vector version of compute_inner_prod_table
    /// (output size presumably nx * M * ksub, as for compute_distance_tables)
    void compute_inner_prod_tables (size_t nx,
                                    const float * x,
                                    float * dis_tables) const;


    /** perform a search (L2 distance)
     * @param x        query vectors, size nx * d
     * @param nx       nb of queries
     * @param codes    database codes, size ncodes * code_size
     * @param ncodes   number of database codes (vectors)
     * @param res      heap array to store results (nh == nx)
     * @param init_finalize_heap  initialize heap (input) and sort (output)?
     */
    void search (const float * x,
                 size_t nx,
                 const uint8_t * codes,
                 const size_t ncodes,
                 float_maxheap_array_t *res,
                 bool init_finalize_heap = true) const;

    /** same as search, but with inner product similarity
     *  (results go to a min-heap array, so larger similarities are kept) */
    void search_ip (const float * x,
                    size_t nx,
                    const uint8_t * codes,
                    const size_t ncodes,
                    float_minheap_array_t *res,
                    bool init_finalize_heap = true) const;


    /// Symmetric Distance Table (filled by compute_sdc_table())
    std::vector<float> sdc_table;

    // initialize the SDC table from the centroids
    void compute_sdc_table ();

    /// Search with symmetric (code-to-code) distances between query codes
    /// qcodes (nq of them) and database codes bcodes (ncodes of them).
    /// NOTE(review): presumably requires compute_sdc_table() to have been
    /// called first -- confirm.
    void search_sdc (const uint8_t * qcodes,
                     size_t nq,
                     const uint8_t * bcodes,
                     const size_t ncodes,
                     float_maxheap_array_t * res,
                     bool init_finalize_heap = true) const;

    /// Writes nbits-bit values into a byte stream starting at `code`.
    /// NOTE(review): `offset` looks like the current bit position inside the
    /// current byte and `reg` the partially filled byte; the destructor
    /// presumably flushes `reg` -- confirm in ProductQuantizer.cpp.
    struct PQEncoderGeneric {
        uint8_t *code;      ///< code for this vector
        uint8_t offset;
        const int nbits;    ///< number of bits per subquantizer index

        uint8_t reg;

        PQEncoderGeneric(uint8_t *code, int nbits, uint8_t offset = 0);

        void encode(uint64_t x);

        ~PQEncoderGeneric();
    };


    /// Encoder writing one uint8_t per value (presumably for nbits == 8).
    struct PQEncoder8 {
        uint8_t *code;

        PQEncoder8(uint8_t *code, int nbits);

        void encode(uint64_t x);
    };

    /// Encoder writing one uint16_t per value (presumably for nbits == 16);
    /// takes a byte pointer but stores it as uint16_t *.
    struct PQEncoder16 {
        uint16_t *code;

        PQEncoder16(uint8_t *code, int nbits);

        void encode(uint64_t x);
    };


    /// Reads nbits-bit values back from a byte stream; counterpart of
    /// PQEncoderGeneric. `mask` presumably selects the low nbits bits.
    struct PQDecoderGeneric {
        const uint8_t *code;
        uint8_t offset;
        const int nbits;
        const uint64_t mask;
        uint8_t reg;

        PQDecoderGeneric(const uint8_t *code, int nbits);

        uint64_t decode();
    };

    /// Decoder reading one uint8_t per value (counterpart of PQEncoder8).
    struct PQDecoder8 {
        const uint8_t *code;

        PQDecoder8(const uint8_t *code, int nbits);

        uint64_t decode();
    };

    /// Decoder reading one uint16_t per value (counterpart of PQEncoder16).
    struct PQDecoder16 {
        const uint16_t *code;

        PQDecoder16(const uint8_t *code, int nbits);

        uint64_t decode();
    };

};
237
+
238
+
239
+ } // namespace faiss
240
+
241
+
242
+ #endif
@@ -0,0 +1,1628 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // -*- c++ -*-
9
+
10
#include <faiss/impl/ScalarQuantizer.h>

#include <algorithm>
#include <cstdio>
#include <cstring>

#include <omp.h>

#ifdef __SSE__
#include <immintrin.h>
#endif

#include <faiss/utils/utils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/impl/AuxIndexStructures.h>
24
+
25
+ namespace faiss {
26
+
27
+ /*******************************************************************
28
+ * ScalarQuantizer implementation
29
+ *
30
+ * The main source of complexity is to support combinations of 4
31
+ * variants without incurring runtime tests or virtual function calls:
32
+ *
33
+ * - 4 / 8 bits per code component
34
+ * - uniform / non-uniform
35
+ * - IP / L2 distance search
36
+ * - scalar / AVX distance computation
37
+ *
38
+ * The appropriate Quantizer object is returned via select_quantizer
39
+ * that hides the template mess.
40
+ ********************************************************************/
41
+
42
// Translate the compiler's ISA feature macros into the local switches used
// throughout this file: USE_AVX guards the __m256 (8-wide) code paths and
// USE_F16C guards the hardware fp16 <-> fp32 conversions.
#if defined(__AVX__)
#define USE_AVX
#endif

#if defined(__F16C__)
#define USE_F16C
#endif
49
+
50
+
51
+ namespace {
52
+
53
+ typedef Index::idx_t idx_t;
54
+ typedef ScalarQuantizer::QuantizerType QuantizerType;
55
+ typedef ScalarQuantizer::RangeStat RangeStat;
56
+ using SQDistanceComputer = ScalarQuantizer::SQDistanceComputer;
57
+
58
+
59
+ /*******************************************************************
60
+ * Codec: converts between values in [0, 1] and an index in a code
61
+ * array. The "i" parameter is the vector component index (not byte
62
+ * index).
63
+ */
64
+
65
/// 8-bit codec: component i of a vector is stored in byte code[i].
/// Values are expected in [0, 1] (the callers normalize/clamp them).
struct Codec8bit {

    /// Store x as (int)(255 * x) in byte i.
    static void encode_component (float x, uint8_t *code, int i) {
        code[i] = (int)(255 * x);
    }

    /// Map byte i back to (code[i] + 0.5) / 255.
    static float decode_component (const uint8_t *code, int i) {
        return (code[i] + 0.5f) / 255.0f;
    }

#ifdef USE_AVX
    /// Decode components i .. i+7 at once (reads the 8 bytes at code + i).
    static __m256 decode_8_components (const uint8_t *code, int i) {
        // memcpy rather than *(uint64_t*)(code + i): the code buffer has no
        // alignment guarantee, and reading uint8_t storage through a
        // uint64_t lvalue breaks strict aliasing. Compilers typically lower
        // this memcpy to a single unaligned load.
        uint64_t c8;
        std::memcpy (&c8, code + i, sizeof (c8));
        // widen the low and high 4 bytes to 32-bit integers
        __m128i c4lo = _mm_cvtepu8_epi32 (_mm_set1_epi32 ((int)(uint32_t)c8));
        __m128i c4hi = _mm_cvtepu8_epi32 (_mm_set1_epi32 ((int)(uint32_t)(c8 >> 32)));
        // __m256i i8 = _mm256_set_m128i(c4lo, c4hi);
        __m256i i8 = _mm256_castsi128_si256 (c4lo);
        i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
        __m256 f8 = _mm256_cvtepi32_ps (i8);
        __m256 half = _mm256_set1_ps (0.5f);
        f8 += half;
        __m256 one_255 = _mm256_set1_ps (1.f / 255.f);
        return f8 * one_255;
    }
#endif
};
91
+
92
+
93
/// 4-bit codec: component i occupies nibble (i & 1) of byte i / 2
/// (even components in the low nibble, odd ones in the high nibble).
/// Values are expected in [0, 1] (the callers normalize/clamp them).
struct Codec4bit {

    /// OR (int)(x * 15) into the nibble of component i. Because it ORs,
    /// the code buffer must be zero-initialized before encoding.
    static void encode_component (float x, uint8_t *code, int i) {
        code [i / 2] |= (int)(x * 15.0) << ((i & 1) << 2);
    }

    /// Map the nibble of component i back to (nibble + 0.5) / 15.
    static float decode_component (const uint8_t *code, int i) {
        return (((code[i / 2] >> ((i & 1) << 2)) & 0xf) + 0.5f) / 15.0f;
    }


#ifdef USE_AVX
    /// Decode components i .. i+7 (i presumably even) from the 4 bytes at
    /// code + i / 2.
    static __m256 decode_8_components (const uint8_t *code, int i) {
        // memcpy rather than *(uint32_t*): no alignment guarantee on the
        // code buffer, and the pointer cast violates strict aliasing.
        uint32_t c4;
        std::memcpy (&c4, code + (i >> 1), sizeof (c4));
        uint32_t mask = 0x0f0f0f0f;
        uint32_t c4ev = c4 & mask;          // even components (low nibbles)
        uint32_t c4od = (c4 >> 4) & mask;   // odd components (high nibbles)

        // the 8 lower bytes of c8 contain the values
        __m128i c8 = _mm_unpacklo_epi8 (_mm_set1_epi32(c4ev),
                                        _mm_set1_epi32(c4od));
        __m128i c4lo = _mm_cvtepu8_epi32 (c8);
        __m128i c4hi = _mm_cvtepu8_epi32 (_mm_srli_si128(c8, 4));
        __m256i i8 = _mm256_castsi128_si256 (c4lo);
        i8 = _mm256_insertf128_si256 (i8, c4hi, 1);
        __m256 f8 = _mm256_cvtepi32_ps (i8);
        __m256 half = _mm256_set1_ps (0.5f);
        f8 += half;
        __m256 one_15 = _mm256_set1_ps (1.f / 15.f);
        return f8 * one_15;
    }
#endif
};
126
+
127
/// 6-bit codec: every group of 4 components is packed into 3 bytes
/// (24 bits). Within a group, component k uses bits 6k .. 6k+5, so
/// components 1 and 2 straddle a byte boundary.
/// Values are expected in [0, 1] (the callers normalize/clamp them).
struct Codec6bit {

    /// OR (int)(x * 63) into the 6-bit slot of component i. Because it ORs,
    /// the code buffer must be zero-initialized before encoding. The
    /// shifted values are deliberately truncated to 8 bits on store.
    static void encode_component (float x, uint8_t *code, int i) {
        int bits = (int)(x * 63.0);
        code += (i >> 2) * 3;
        switch(i & 3) {
        case 0:
            code[0] |= bits;
            break;
        case 1:
            code[0] |= bits << 6;    // low 2 bits
            code[1] |= bits >> 2;    // high 4 bits
            break;
        case 2:
            code[1] |= bits << 4;    // low 4 bits
            code[2] |= bits >> 4;    // high 2 bits
            break;
        case 3:
            code[2] |= bits << 2;
            break;
        }
    }

    /// Extract the 6-bit slot of component i and map it to (bits + 0.5) / 63.
    static float decode_component (const uint8_t *code, int i) {
        // Initialized so no path reads an indeterminate value; the switch
        // below covers every value of (i & 3) but compilers cannot always
        // prove it and warn (-Wmaybe-uninitialized / -Wsometimes-uninitialized).
        uint8_t bits = 0;
        code += (i >> 2) * 3;
        switch(i & 3) {
        case 0:
            bits = code[0] & 0x3f;
            break;
        case 1:
            bits = code[0] >> 6;
            bits |= (code[1] & 0xf) << 2;
            break;
        case 2:
            bits = code[1] >> 4;
            bits |= (code[2] & 3) << 4;
            break;
        case 3:
            bits = code[2] >> 2;
            break;
        }
        return (bits + 0.5f) / 63.0f;
    }

#ifdef USE_AVX
    /// No vectorized unpacking: decode the 8 components one by one.
    static __m256 decode_8_components (const uint8_t *code, int i) {
        return _mm256_set_ps
            (decode_component(code, i + 7),
             decode_component(code, i + 6),
             decode_component(code, i + 5),
             decode_component(code, i + 4),
             decode_component(code, i + 3),
             decode_component(code, i + 2),
             decode_component(code, i + 1),
             decode_component(code, i + 0));
    }
#endif
};
186
+
187
+
188
+
189
#ifdef USE_F16C

/// Convert a float to IEEE half-precision bits with the F16C instruction
/// (round to nearest).
uint16_t encode_fp16 (float x) {
    __m128 xf = _mm_set1_ps (x);
    __m128i xi = _mm_cvtps_ph (
         xf, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC);
    return _mm_cvtsi128_si32 (xi) & 0xffff;
}


/// Convert IEEE half-precision bits to a float with the F16C instruction.
float decode_fp16 (uint16_t x) {
    __m128i xi = _mm_set1_epi16 (x);
    __m128 xf = _mm_cvtph_ps (xi);
    return _mm_cvtss_f32 (xf);
}

#else

// non-intrinsic FP16 <-> FP32 code adapted from
// https://github.com/ispc/ispc/blob/master/stdlib.ispc

/// Reinterpret the bits of x as a float. memcpy is the well-defined way to
/// type-pun in C++; casting &x to float* violates strict aliasing.
float floatbits (uint32_t x) {
    float f;
    std::memcpy (&f, &x, sizeof (f));
    return f;
}

/// Reinterpret the bits of f as a uint32_t (see floatbits).
uint32_t intbits (float f) {
    uint32_t x;
    std::memcpy (&x, &f, sizeof (x));
    return x;
}


/// Portable float -> half conversion (via Fabian "ryg" Giesen,
/// https://gist.github.com/2156668). Inf stays Inf and NaN becomes a qNaN;
/// magnitudes beyond the half range are clamped before rounding, which
/// makes them come out as Inf.
uint16_t encode_fp16 (float f) {
    uint32_t sign_mask = 0x80000000u;
    uint32_t o;

    uint32_t fint = intbits(f);
    uint32_t sign = fint & sign_mask;
    fint ^= sign;

    // Inf or NaN (all exponent bits set): NaN->qNaN and Inf->Inf.
    // Unconditional assignment here, overridden with the right value for
    // the regular case below.
    uint32_t f32infty = 255u << 23;
    o = (fint > f32infty) ? 0x7e00u : 0x7c00u;

    // (De)normalized number or zero.
    const uint32_t round_mask = ~0xfffu;
    const uint32_t magic = 15u << 23;

    // Shift the exponent down (rebias by multiplying with 2^-112),
    // denormalizing if necessary. Half-float denormals are produced via
    // single-precision denormals, so this path respects the FTZ flag and
    // may hit the hardware's slow denormal path.
    float fscale = floatbits(fint & round_mask) * floatbits(magic);
    fscale = std::min(fscale, floatbits((31u << 23) - 0x1000u));
    // subtracting ~0xfff adds 0x1000, i.e. rounds on bit 12
    uint32_t fint2 = intbits(fscale) - round_mask;

    if (fint < f32infty)
        o = fint2 >> 13; // Take the bits!

    return (o | (sign >> 16));
}

/// Portable half -> float conversion (via Fabian "ryg" Giesen,
/// https://gist.github.com/2144712). Handles zeros, denormals, Inf and NaN.
/// All bit manipulation is done in uint32_t so that placing the sign in
/// bit 31 involves no signed shift.
float decode_fp16 (uint16_t h) {
    const uint32_t shifted_exp = 0x7c00u << 13; // exponent mask after shift

    uint32_t o = ((uint32_t)(h & 0x7fffu)) << 13;  // exponent/mantissa bits
    uint32_t exp = shifted_exp & o;                // just the exponent
    o += (127u - 15u) << 23;                       // exponent adjust

    uint32_t infnan_val = o + ((128u - 16u) << 23);
    // renormalize zero/denormal inputs through a float subtraction
    uint32_t zerodenorm_val = intbits(
         floatbits(o + (1u << 23)) - floatbits(113u << 23));
    uint32_t reg_val = (exp == 0) ? zerodenorm_val : o;

    uint32_t sign_bit = ((uint32_t)(h & 0x8000u)) << 16;
    return floatbits(((exp == shifted_exp) ? infnan_val : reg_val) | sign_bit);
}

#endif
301
+
302
+
303
+
304
+ /*******************************************************************
305
+ * Quantizer: normalizes scalar vector components, then passes them
306
+ * through a codec
307
+ *******************************************************************/
308
+
309
+
310
+
311
+
312
+
313
// Primary template, deliberately empty: the usable variants are the
// (uniform / non-uniform) x (SIMD width 1 / 8) specializations.
template<typename Codec, bool uniform, int SIMD>
struct QuantizerTemplate {
};
315
+
316
+
317
+ template<class Codec>
318
+ struct QuantizerTemplate<Codec, true, 1>: ScalarQuantizer::Quantizer {
319
+ const size_t d;
320
+ const float vmin, vdiff;
321
+
322
+ QuantizerTemplate(size_t d, const std::vector<float> &trained):
323
+ d(d), vmin(trained[0]), vdiff(trained[1])
324
+ {
325
+ }
326
+
327
+ void encode_vector(const float* x, uint8_t* code) const final {
328
+ for (size_t i = 0; i < d; i++) {
329
+ float xi = (x[i] - vmin) / vdiff;
330
+ if (xi < 0) {
331
+ xi = 0;
332
+ }
333
+ if (xi > 1.0) {
334
+ xi = 1.0;
335
+ }
336
+ Codec::encode_component(xi, code, i);
337
+ }
338
+ }
339
+
340
+ void decode_vector(const uint8_t* code, float* x) const final {
341
+ for (size_t i = 0; i < d; i++) {
342
+ float xi = Codec::decode_component(code, i);
343
+ x[i] = vmin + xi * vdiff;
344
+ }
345
+ }
346
+
347
+ float reconstruct_component (const uint8_t * code, int i) const
348
+ {
349
+ float xi = Codec::decode_component (code, i);
350
+ return vmin + xi * vdiff;
351
+ }
352
+
353
+ };
354
+
355
+
356
+
357
#ifdef USE_AVX

/// AVX variant of the uniform quantizer: adds an 8-components-at-a-time
/// reconstruction on top of the scalar implementation.
template<class Codec>
struct QuantizerTemplate<Codec, true, 8>: QuantizerTemplate<Codec, true, 1> {

    QuantizerTemplate (size_t d, const std::vector<float> &trained):
        QuantizerTemplate<Codec, true, 1> (d, trained) {}

    /// Reconstruct components i .. i+7: vmin + decoded * vdiff, lane-wise.
    __m256 reconstruct_8_components (const uint8_t * code, int i) const
    {
        const __m256 lo = _mm256_set1_ps (this->vmin);
        const __m256 span = _mm256_set1_ps (this->vdiff);
        const __m256 unit = Codec::decode_8_components (code, i);
        return lo + unit * span;
    }

};

#endif
374
+
375
+
376
+
377
+ template<class Codec>
378
+ struct QuantizerTemplate<Codec, false, 1>: ScalarQuantizer::Quantizer {
379
+ const size_t d;
380
+ const float *vmin, *vdiff;
381
+
382
+ QuantizerTemplate (size_t d, const std::vector<float> &trained):
383
+ d(d), vmin(trained.data()), vdiff(trained.data() + d) {}
384
+
385
+ void encode_vector(const float* x, uint8_t* code) const final {
386
+ for (size_t i = 0; i < d; i++) {
387
+ float xi = (x[i] - vmin[i]) / vdiff[i];
388
+ if (xi < 0)
389
+ xi = 0;
390
+ if (xi > 1.0)
391
+ xi = 1.0;
392
+ Codec::encode_component(xi, code, i);
393
+ }
394
+ }
395
+
396
+ void decode_vector(const uint8_t* code, float* x) const final {
397
+ for (size_t i = 0; i < d; i++) {
398
+ float xi = Codec::decode_component(code, i);
399
+ x[i] = vmin[i] + xi * vdiff[i];
400
+ }
401
+ }
402
+
403
+ float reconstruct_component (const uint8_t * code, int i) const
404
+ {
405
+ float xi = Codec::decode_component (code, i);
406
+ return vmin[i] + xi * vdiff[i];
407
+ }
408
+
409
+ };
410
+
411
+
412
#ifdef USE_AVX

/// AVX variant of the non-uniform quantizer: 8 components at a time, each
/// with its own vmin / vdiff (loaded unaligned from the trained arrays).
template<class Codec>
struct QuantizerTemplate<Codec, false, 8>: QuantizerTemplate<Codec, false, 1> {

    QuantizerTemplate (size_t d, const std::vector<float> &trained):
        QuantizerTemplate<Codec, false, 1> (d, trained) {}

    /// Reconstruct components i .. i+7: vmin[k] + decoded_k * vdiff[k].
    __m256 reconstruct_8_components (const uint8_t * code, int i) const
    {
        const __m256 lo = _mm256_loadu_ps (this->vmin + i);
        const __m256 span = _mm256_loadu_ps (this->vdiff + i);
        const __m256 unit = Codec::decode_8_components (code, i);
        return lo + unit * span;
    }

};

#endif
430
+
431
+ /*******************************************************************
432
+ * FP16 quantizer
433
+ *******************************************************************/
434
+
435
+ template<int SIMDWIDTH>
436
+ struct QuantizerFP16 {};
437
+
438
+ template<>
439
+ struct QuantizerFP16<1>: ScalarQuantizer::Quantizer {
440
+ const size_t d;
441
+
442
+ QuantizerFP16(size_t d, const std::vector<float> & /* unused */):
443
+ d(d) {}
444
+
445
+ void encode_vector(const float* x, uint8_t* code) const final {
446
+ for (size_t i = 0; i < d; i++) {
447
+ ((uint16_t*)code)[i] = encode_fp16(x[i]);
448
+ }
449
+ }
450
+
451
+ void decode_vector(const uint8_t* code, float* x) const final {
452
+ for (size_t i = 0; i < d; i++) {
453
+ x[i] = decode_fp16(((uint16_t*)code)[i]);
454
+ }
455
+ }
456
+
457
+ float reconstruct_component (const uint8_t * code, int i) const
458
+ {
459
+ return decode_fp16(((uint16_t*)code)[i]);
460
+ }
461
+
462
+ };
463
+
464
#ifdef USE_F16C

/// 8-wide FP16 reconstruction using the F16C half-to-float conversion.
template<>
struct QuantizerFP16<8>: QuantizerFP16<1> {

    QuantizerFP16 (size_t d, const std::vector<float> &trained):
        QuantizerFP16<1> (d, trained) {}

    __m256 reconstruct_8_components (const uint8_t * code, int i) const
    {
        // 8 halves = 16 bytes, starting at byte offset 2 * i
        const __m128i halves =
            _mm_loadu_si128 (reinterpret_cast<const __m128i*>(code + 2 * i));
        return _mm256_cvtph_ps (halves);
    }

};

#endif
481
+
482
+ /*******************************************************************
483
+ * 8bit_direct quantizer
484
+ *******************************************************************/
485
+
486
+ template<int SIMDWIDTH>
487
+ struct Quantizer8bitDirect {};
488
+
489
+ template<>
490
+ struct Quantizer8bitDirect<1>: ScalarQuantizer::Quantizer {
491
+ const size_t d;
492
+
493
+ Quantizer8bitDirect(size_t d, const std::vector<float> & /* unused */):
494
+ d(d) {}
495
+
496
+
497
+ void encode_vector(const float* x, uint8_t* code) const final {
498
+ for (size_t i = 0; i < d; i++) {
499
+ code[i] = (uint8_t)x[i];
500
+ }
501
+ }
502
+
503
+ void decode_vector(const uint8_t* code, float* x) const final {
504
+ for (size_t i = 0; i < d; i++) {
505
+ x[i] = code[i];
506
+ }
507
+ }
508
+
509
+ float reconstruct_component (const uint8_t * code, int i) const
510
+ {
511
+ return code[i];
512
+ }
513
+
514
+ };
515
+
516
#ifdef USE_AVX

template<>
struct Quantizer8bitDirect<8>: Quantizer8bitDirect<1> {

    Quantizer8bitDirect (size_t d, const std::vector<float> &trained):
        Quantizer8bitDirect<1> (d, trained) {}

    /// Widen bytes i..i+7 into 8 floats.
    __m256 reconstruct_8_components (const uint8_t * code, int i) const
    {
        const __m128i bytes =
            _mm_loadl_epi64 (reinterpret_cast<const __m128i*>(code + i)); // 8 x uint8
        const __m256i ints = _mm256_cvtepu8_epi32 (bytes);                // 8 x int32
        return _mm256_cvtepi32_ps (ints);                                  // 8 x float32
    }

};

#endif
534
+
535
+
536
+ template<int SIMDWIDTH>
537
+ ScalarQuantizer::Quantizer *select_quantizer_1 (
538
+ QuantizerType qtype,
539
+ size_t d, const std::vector<float> & trained)
540
+ {
541
+ switch(qtype) {
542
+ case ScalarQuantizer::QT_8bit:
543
+ return new QuantizerTemplate<Codec8bit, false, SIMDWIDTH>(d, trained);
544
+ case ScalarQuantizer::QT_6bit:
545
+ return new QuantizerTemplate<Codec6bit, false, SIMDWIDTH>(d, trained);
546
+ case ScalarQuantizer::QT_4bit:
547
+ return new QuantizerTemplate<Codec4bit, false, SIMDWIDTH>(d, trained);
548
+ case ScalarQuantizer::QT_8bit_uniform:
549
+ return new QuantizerTemplate<Codec8bit, true, SIMDWIDTH>(d, trained);
550
+ case ScalarQuantizer::QT_4bit_uniform:
551
+ return new QuantizerTemplate<Codec4bit, true, SIMDWIDTH>(d, trained);
552
+ case ScalarQuantizer::QT_fp16:
553
+ return new QuantizerFP16<SIMDWIDTH> (d, trained);
554
+ case ScalarQuantizer::QT_8bit_direct:
555
+ return new Quantizer8bitDirect<SIMDWIDTH> (d, trained);
556
+ }
557
+ FAISS_THROW_MSG ("unknown qtype");
558
+ }
559
+
560
+
561
+
562
+
563
/*******************************************************************
 * Quantizer range training
 *******************************************************************/

/// Square of a float (avoids calling pow for a small integer exponent).
static inline float sqr (float x) { return x * x; }
570
+
571
+
572
+ void train_Uniform(RangeStat rs, float rs_arg,
573
+ idx_t n, int k, const float *x,
574
+ std::vector<float> & trained)
575
+ {
576
+ trained.resize (2);
577
+ float & vmin = trained[0];
578
+ float & vmax = trained[1];
579
+
580
+ if (rs == ScalarQuantizer::RS_minmax) {
581
+ vmin = HUGE_VAL; vmax = -HUGE_VAL;
582
+ for (size_t i = 0; i < n; i++) {
583
+ if (x[i] < vmin) vmin = x[i];
584
+ if (x[i] > vmax) vmax = x[i];
585
+ }
586
+ float vexp = (vmax - vmin) * rs_arg;
587
+ vmin -= vexp;
588
+ vmax += vexp;
589
+ } else if (rs == ScalarQuantizer::RS_meanstd) {
590
+ double sum = 0, sum2 = 0;
591
+ for (size_t i = 0; i < n; i++) {
592
+ sum += x[i];
593
+ sum2 += x[i] * x[i];
594
+ }
595
+ float mean = sum / n;
596
+ float var = sum2 / n - mean * mean;
597
+ float std = var <= 0 ? 1.0 : sqrt(var);
598
+
599
+ vmin = mean - std * rs_arg ;
600
+ vmax = mean + std * rs_arg ;
601
+ } else if (rs == ScalarQuantizer::RS_quantiles) {
602
+ std::vector<float> x_copy(n);
603
+ memcpy(x_copy.data(), x, n * sizeof(*x));
604
+ // TODO just do a qucikselect
605
+ std::sort(x_copy.begin(), x_copy.end());
606
+ int o = int(rs_arg * n);
607
+ if (o < 0) o = 0;
608
+ if (o > n - o) o = n / 2;
609
+ vmin = x_copy[o];
610
+ vmax = x_copy[n - 1 - o];
611
+
612
+ } else if (rs == ScalarQuantizer::RS_optim) {
613
+ float a, b;
614
+ float sx = 0;
615
+ {
616
+ vmin = HUGE_VAL, vmax = -HUGE_VAL;
617
+ for (size_t i = 0; i < n; i++) {
618
+ if (x[i] < vmin) vmin = x[i];
619
+ if (x[i] > vmax) vmax = x[i];
620
+ sx += x[i];
621
+ }
622
+ b = vmin;
623
+ a = (vmax - vmin) / (k - 1);
624
+ }
625
+ int verbose = false;
626
+ int niter = 2000;
627
+ float last_err = -1;
628
+ int iter_last_err = 0;
629
+ for (int it = 0; it < niter; it++) {
630
+ float sn = 0, sn2 = 0, sxn = 0, err1 = 0;
631
+
632
+ for (idx_t i = 0; i < n; i++) {
633
+ float xi = x[i];
634
+ float ni = floor ((xi - b) / a + 0.5);
635
+ if (ni < 0) ni = 0;
636
+ if (ni >= k) ni = k - 1;
637
+ err1 += sqr (xi - (ni * a + b));
638
+ sn += ni;
639
+ sn2 += ni * ni;
640
+ sxn += ni * xi;
641
+ }
642
+
643
+ if (err1 == last_err) {
644
+ iter_last_err ++;
645
+ if (iter_last_err == 16) break;
646
+ } else {
647
+ last_err = err1;
648
+ iter_last_err = 0;
649
+ }
650
+
651
+ float det = sqr (sn) - sn2 * n;
652
+
653
+ b = (sn * sxn - sn2 * sx) / det;
654
+ a = (sn * sx - n * sxn) / det;
655
+ if (verbose) {
656
+ printf ("it %d, err1=%g \r", it, err1);
657
+ fflush(stdout);
658
+ }
659
+ }
660
+ if (verbose) printf("\n");
661
+
662
+ vmin = b;
663
+ vmax = b + a * (k - 1);
664
+
665
+ } else {
666
+ FAISS_THROW_MSG ("Invalid qtype");
667
+ }
668
+ vmax -= vmin;
669
+ }
670
+
671
+ void train_NonUniform(RangeStat rs, float rs_arg,
672
+ idx_t n, int d, int k, const float *x,
673
+ std::vector<float> & trained)
674
+ {
675
+
676
+ trained.resize (2 * d);
677
+ float * vmin = trained.data();
678
+ float * vmax = trained.data() + d;
679
+ if (rs == ScalarQuantizer::RS_minmax) {
680
+ memcpy (vmin, x, sizeof(*x) * d);
681
+ memcpy (vmax, x, sizeof(*x) * d);
682
+ for (size_t i = 1; i < n; i++) {
683
+ const float *xi = x + i * d;
684
+ for (size_t j = 0; j < d; j++) {
685
+ if (xi[j] < vmin[j]) vmin[j] = xi[j];
686
+ if (xi[j] > vmax[j]) vmax[j] = xi[j];
687
+ }
688
+ }
689
+ float *vdiff = vmax;
690
+ for (size_t j = 0; j < d; j++) {
691
+ float vexp = (vmax[j] - vmin[j]) * rs_arg;
692
+ vmin[j] -= vexp;
693
+ vmax[j] += vexp;
694
+ vdiff [j] = vmax[j] - vmin[j];
695
+ }
696
+ } else {
697
+ // transpose
698
+ std::vector<float> xt(n * d);
699
+ for (size_t i = 1; i < n; i++) {
700
+ const float *xi = x + i * d;
701
+ for (size_t j = 0; j < d; j++) {
702
+ xt[j * n + i] = xi[j];
703
+ }
704
+ }
705
+ std::vector<float> trained_d(2);
706
+ #pragma omp parallel for
707
+ for (size_t j = 0; j < d; j++) {
708
+ train_Uniform(rs, rs_arg,
709
+ n, k, xt.data() + j * n,
710
+ trained_d);
711
+ vmin[j] = trained_d[0];
712
+ vmax[j] = trained_d[1];
713
+ }
714
+ }
715
+ }
716
+
717
+
718
+
719
+ /*******************************************************************
720
+ * Similarity: gets vector components and computes a similarity wrt. a
721
+ * query vector stored in the object. The data fields just encapsulate
722
+ * an accumulator.
723
+ */
724
+
725
+ template<int SIMDWIDTH>
726
+ struct SimilarityL2 {};
727
+
728
+
729
+ template<>
730
+ struct SimilarityL2<1> {
731
+ static constexpr int simdwidth = 1;
732
+ static constexpr MetricType metric_type = METRIC_L2;
733
+
734
+ const float *y, *yi;
735
+
736
+ explicit SimilarityL2 (const float * y): y(y) {}
737
+
738
+ /******* scalar accumulator *******/
739
+
740
+ float accu;
741
+
742
+ void begin () {
743
+ accu = 0;
744
+ yi = y;
745
+ }
746
+
747
+ void add_component (float x) {
748
+ float tmp = *yi++ - x;
749
+ accu += tmp * tmp;
750
+ }
751
+
752
+ void add_component_2 (float x1, float x2) {
753
+ float tmp = x1 - x2;
754
+ accu += tmp * tmp;
755
+ }
756
+
757
+ float result () {
758
+ return accu;
759
+ }
760
+ };
761
+
762
+
763
#ifdef USE_AVX
/// 8-wide (AVX) squared-L2 accumulator.
template<>
struct SimilarityL2<8> {
    static constexpr int simdwidth = 8;
    static constexpr MetricType metric_type = METRIC_L2;

    const float *y, *yi;

    explicit SimilarityL2 (const float * y): y(y) {}
    __m256 accu8;

    void begin_8 () {
        yi = y;
        accu8 = _mm256_setzero_ps();
    }

    /// Add the squared differences between the next 8 query components
    /// and x.
    void add_8_components (__m256 x) {
        const __m256 yv = _mm256_loadu_ps (yi);
        yi += 8;
        const __m256 diff = _mm256_sub_ps (yv, x);
        accu8 = _mm256_add_ps (accu8, _mm256_mul_ps (diff, diff));
    }

    /// Add the squared differences between two decoded 8-lane vectors.
    void add_8_components_2 (__m256 x, __m256 y) {
        const __m256 diff = _mm256_sub_ps (y, x);
        accu8 = _mm256_add_ps (accu8, _mm256_mul_ps (diff, diff));
    }

    /// Horizontal sum of the 8 lanes.
    float result_8 () {
        // After two hadds, lane 0 holds lanes 0..3 summed and lane 4 holds
        // lanes 4..7 summed.
        const __m256 pairs = _mm256_hadd_ps (accu8, accu8);
        const __m256 quads = _mm256_hadd_ps (pairs, pairs);
        const float lo = _mm_cvtss_f32 (_mm256_castps256_ps128 (quads));
        const float hi = _mm_cvtss_f32 (_mm256_extractf128_ps (quads, 1));
        return lo + hi;
    }

};

#endif
803
+
804
+
805
+ template<int SIMDWIDTH>
806
+ struct SimilarityIP {};
807
+
808
+
809
+ template<>
810
+ struct SimilarityIP<1> {
811
+ static constexpr int simdwidth = 1;
812
+ static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;
813
+ const float *y, *yi;
814
+
815
+ float accu;
816
+
817
+ explicit SimilarityIP (const float * y):
818
+ y (y) {}
819
+
820
+ void begin () {
821
+ accu = 0;
822
+ yi = y;
823
+ }
824
+
825
+ void add_component (float x) {
826
+ accu += *yi++ * x;
827
+ }
828
+
829
+ void add_component_2 (float x1, float x2) {
830
+ accu += x1 * x2;
831
+ }
832
+
833
+ float result () {
834
+ return accu;
835
+ }
836
+ };
837
+
838
#ifdef USE_AVX

/// 8-wide (AVX) inner-product accumulator. Only the __m256 accumulator is
/// used on this path.
template<>
struct SimilarityIP<8> {
    static constexpr int simdwidth = 8;
    static constexpr MetricType metric_type = METRIC_INNER_PRODUCT;

    const float *y, *yi;

    explicit SimilarityIP (const float * y):
        y (y) {}

    __m256 accu8;

    void begin_8 () {
        accu8 = _mm256_setzero_ps();
        yi = y;
    }

    /// Add the products of the next 8 query components with x.
    void add_8_components (__m256 x) {
        const __m256 yiv = _mm256_loadu_ps (yi);
        yi += 8;
        accu8 = _mm256_add_ps (accu8, _mm256_mul_ps (yiv, x));
    }

    /// Add the lane-wise products of two decoded 8-lane vectors.
    void add_8_components_2 (__m256 x1, __m256 x2) {
        accu8 = _mm256_add_ps (accu8, _mm256_mul_ps (x1, x2));
    }

    /// Horizontal sum of the 8 lanes.
    float result_8 () {
        __m256 sum = _mm256_hadd_ps(accu8, accu8);
        __m256 sum2 = _mm256_hadd_ps(sum, sum);
        // now add the 0th and 4th component
        return
            _mm_cvtss_f32 (_mm256_castps256_ps128(sum2)) +
            _mm_cvtss_f32 (_mm256_extractf128_ps(sum2, 1));
    }
};
#endif
879
+
880
+
881
+ /*******************************************************************
882
+ * DistanceComputer: combines a similarity and a quantizer to do
883
+ * code-to-vector or code-to-code comparisons
884
+ *******************************************************************/
885
+
886
+ template<class Quantizer, class Similarity, int SIMDWIDTH>
887
+ struct DCTemplate : SQDistanceComputer {};
888
+
889
+ template<class Quantizer, class Similarity>
890
+ struct DCTemplate<Quantizer, Similarity, 1> : SQDistanceComputer
891
+ {
892
+ using Sim = Similarity;
893
+
894
+ Quantizer quant;
895
+
896
+ DCTemplate(size_t d, const std::vector<float> &trained):
897
+ quant(d, trained)
898
+ {}
899
+
900
+ float compute_distance(const float* x, const uint8_t* code) const {
901
+
902
+ Similarity sim(x);
903
+ sim.begin();
904
+ for (size_t i = 0; i < quant.d; i++) {
905
+ float xi = quant.reconstruct_component(code, i);
906
+ sim.add_component(xi);
907
+ }
908
+ return sim.result();
909
+ }
910
+
911
+ float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
912
+ const {
913
+ Similarity sim(nullptr);
914
+ sim.begin();
915
+ for (size_t i = 0; i < quant.d; i++) {
916
+ float x1 = quant.reconstruct_component(code1, i);
917
+ float x2 = quant.reconstruct_component(code2, i);
918
+ sim.add_component_2(x1, x2);
919
+ }
920
+ return sim.result();
921
+ }
922
+
923
+ void set_query (const float *x) final {
924
+ q = x;
925
+ }
926
+
927
+ /// compute distance of vector i to current query
928
+ float operator () (idx_t i) final {
929
+ return compute_distance (q, codes + i * code_size);
930
+ }
931
+
932
+ float symmetric_dis (idx_t i, idx_t j) override {
933
+ return compute_code_distance (codes + i * code_size,
934
+ codes + j * code_size);
935
+ }
936
+
937
+ float query_to_code (const uint8_t * code) const {
938
+ return compute_distance (q, code);
939
+ }
940
+
941
+ };
942
+
943
#ifdef USE_F16C

/// 8-wide distance computer. It reconstructs and compares 8 components per
/// step, so quant.d must be a multiple of 8; get_distance_computer only
/// selects it when d % 8 == 0.
template<class Quantizer, class Similarity>
struct DCTemplate<Quantizer, Similarity, 8> : SQDistanceComputer
{
    using Sim = Similarity;

    Quantizer quant;

    DCTemplate(size_t d, const std::vector<float> &trained):
        quant(d, trained)
    {}

    float compute_distance(const float* x, const uint8_t* code) const {
        Similarity sim(x);
        sim.begin_8();
        for (size_t j = 0; j < quant.d; j += 8)
            sim.add_8_components(quant.reconstruct_8_components(code, j));
        return sim.result_8();
    }

    float compute_code_distance(const uint8_t* code1, const uint8_t* code2)
        const {
        Similarity sim(nullptr);
        sim.begin_8();
        for (size_t j = 0; j < quant.d; j += 8) {
            const __m256 a = quant.reconstruct_8_components(code1, j);
            const __m256 b = quant.reconstruct_8_components(code2, j);
            sim.add_8_components_2(a, b);
        }
        return sim.result_8();
    }

    void set_query (const float *x) final {
        q = x;
    }

    /// compute distance of vector i to current query
    float operator () (idx_t i) final {
        return query_to_code (codes + i * code_size);
    }

    float symmetric_dis (idx_t i, idx_t j) override {
        const uint8_t *ci = codes + i * code_size;
        const uint8_t *cj = codes + j * code_size;
        return compute_code_distance (ci, cj);
    }

    float query_to_code (const uint8_t * code) const {
        return compute_distance (q, code);
    }

};

#endif
1000
+
1001
+
1002
+
1003
+ /*******************************************************************
1004
+ * DistanceComputerByte: computes distances in the integer domain
1005
+ *******************************************************************/
1006
+
1007
+ template<class Similarity, int SIMDWIDTH>
1008
+ struct DistanceComputerByte : SQDistanceComputer {};
1009
+
1010
+ template<class Similarity>
1011
+ struct DistanceComputerByte<Similarity, 1> : SQDistanceComputer {
1012
+ using Sim = Similarity;
1013
+
1014
+ int d;
1015
+ std::vector<uint8_t> tmp;
1016
+
1017
+ DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
1018
+ }
1019
+
1020
+ int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
1021
+ const {
1022
+ int accu = 0;
1023
+ for (int i = 0; i < d; i++) {
1024
+ if (Sim::metric_type == METRIC_INNER_PRODUCT) {
1025
+ accu += int(code1[i]) * code2[i];
1026
+ } else {
1027
+ int diff = int(code1[i]) - code2[i];
1028
+ accu += diff * diff;
1029
+ }
1030
+ }
1031
+ return accu;
1032
+ }
1033
+
1034
+ void set_query (const float *x) final {
1035
+ for (int i = 0; i < d; i++) {
1036
+ tmp[i] = int(x[i]);
1037
+ }
1038
+ }
1039
+
1040
+ int compute_distance(const float* x, const uint8_t* code) {
1041
+ set_query(x);
1042
+ return compute_code_distance(tmp.data(), code);
1043
+ }
1044
+
1045
+ /// compute distance of vector i to current query
1046
+ float operator () (idx_t i) final {
1047
+ return compute_distance (q, codes + i * code_size);
1048
+ }
1049
+
1050
+ float symmetric_dis (idx_t i, idx_t j) override {
1051
+ return compute_code_distance (codes + i * code_size,
1052
+ codes + j * code_size);
1053
+ }
1054
+
1055
+ float query_to_code (const uint8_t * code) const {
1056
+ return compute_code_distance (tmp.data(), code);
1057
+ }
1058
+
1059
+ };
1060
+
1061
#ifdef USE_AVX


/// SIMD byte-domain distance computer. It handles 16 components per step,
/// so d must be a multiple of 16; the selectors check d % 16 == 0.
/// NOTE(review): the 256-bit integer intrinsics used here
/// (_mm256_cvtepu8_epi16, _mm256_madd_epi16) are AVX2 instructions.
template<class Similarity>
struct DistanceComputerByte<Similarity, 8> : SQDistanceComputer {
    using Sim = Similarity;

    int d;
    std::vector<uint8_t> tmp;   // query converted to bytes by set_query

    DistanceComputerByte(int d, const std::vector<float> &): d(d), tmp(d) {
    }

    /// Integer inner product or squared L2 distance between two byte codes.
    int compute_code_distance(const uint8_t* code1, const uint8_t* code2)
        const {
        __m256i accu = _mm256_setzero_si256 ();
        for (int i = 0; i < d; i += 16) {
            // load 16 bytes, convert to 16 uint16_t
            __m256i c1 = _mm256_cvtepu8_epi16
                (_mm_loadu_si128((__m128i*)(code1 + i)));
            __m256i c2 = _mm256_cvtepu8_epi16
                (_mm_loadu_si128((__m128i*)(code2 + i)));
            __m256i prod32;
            if (Sim::metric_type == METRIC_INNER_PRODUCT) {
                prod32 = _mm256_madd_epi16(c1, c2);
            } else {
                __m256i diff = _mm256_sub_epi16(c1, c2);
                prod32 = _mm256_madd_epi16(diff, diff);
            }
            accu = _mm256_add_epi32 (accu, prod32);

        }
        // reduce the 8 int32 lanes to one value
        __m128i sum = _mm256_extractf128_si256(accu, 0);
        sum = _mm_add_epi32 (sum, _mm256_extractf128_si256(accu, 1));
        sum = _mm_hadd_epi32 (sum, sum);
        sum = _mm_hadd_epi32 (sum, sum);
        return _mm_cvtsi128_si32 (sum);
    }

    /// Store the query: convert it to bytes (int truncation) in tmp, and
    /// record the pointer in q.
    void set_query (const float *x) final {
        q = x;
        for (int i = 0; i < d; i++) {
            tmp[i] = int(x[i]);
        }
    }

    int compute_distance(const float* x, const uint8_t* code) {
        set_query(x);
        return compute_code_distance(tmp.data(), code);
    }

    /// compute distance of vector i to current query
    float operator () (idx_t i) final {
        // Use the bytes prepared by set_query instead of reconverting from
        // q, which set_query used to leave unset.
        return query_to_code (codes + i * code_size);
    }

    float symmetric_dis (idx_t i, idx_t j) override {
        return compute_code_distance (codes + i * code_size,
                                      codes + j * code_size);
    }

    float query_to_code (const uint8_t * code) const {
        return compute_code_distance (tmp.data(), code);
    }


};

#endif
1135
+
1136
+ /*******************************************************************
1137
+ * select_distance_computer: runtime selection of template
1138
+ * specialization
1139
+ *******************************************************************/
1140
+
1141
+
1142
+ template<class Sim>
1143
+ SQDistanceComputer *select_distance_computer (
1144
+ QuantizerType qtype,
1145
+ size_t d, const std::vector<float> & trained)
1146
+ {
1147
+ constexpr int SIMDWIDTH = Sim::simdwidth;
1148
+ switch(qtype) {
1149
+ case ScalarQuantizer::QT_8bit_uniform:
1150
+ return new DCTemplate<QuantizerTemplate<Codec8bit, true, SIMDWIDTH>,
1151
+ Sim, SIMDWIDTH>(d, trained);
1152
+
1153
+ case ScalarQuantizer::QT_4bit_uniform:
1154
+ return new DCTemplate<QuantizerTemplate<Codec4bit, true, SIMDWIDTH>,
1155
+ Sim, SIMDWIDTH>(d, trained);
1156
+
1157
+ case ScalarQuantizer::QT_8bit:
1158
+ return new DCTemplate<QuantizerTemplate<Codec8bit, false, SIMDWIDTH>,
1159
+ Sim, SIMDWIDTH>(d, trained);
1160
+
1161
+ case ScalarQuantizer::QT_6bit:
1162
+ return new DCTemplate<QuantizerTemplate<Codec6bit, false, SIMDWIDTH>,
1163
+ Sim, SIMDWIDTH>(d, trained);
1164
+
1165
+ case ScalarQuantizer::QT_4bit:
1166
+ return new DCTemplate<QuantizerTemplate<Codec4bit, false, SIMDWIDTH>,
1167
+ Sim, SIMDWIDTH>(d, trained);
1168
+
1169
+ case ScalarQuantizer::QT_fp16:
1170
+ return new DCTemplate
1171
+ <QuantizerFP16<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
1172
+
1173
+ case ScalarQuantizer::QT_8bit_direct:
1174
+ if (d % 16 == 0) {
1175
+ return new DistanceComputerByte<Sim, SIMDWIDTH>(d, trained);
1176
+ } else {
1177
+ return new DCTemplate
1178
+ <Quantizer8bitDirect<SIMDWIDTH>, Sim, SIMDWIDTH>(d, trained);
1179
+ }
1180
+ }
1181
+ FAISS_THROW_MSG ("unknown qtype");
1182
+ return nullptr;
1183
+ }
1184
+
1185
+
1186
+
1187
+ } // anonymous namespace
1188
+
1189
+
1190
+
1191
+ /*******************************************************************
1192
+ * ScalarQuantizer implementation
1193
+ ********************************************************************/
1194
+
1195
+
1196
+
1197
+ ScalarQuantizer::ScalarQuantizer
1198
+ (size_t d, QuantizerType qtype):
1199
+ qtype (qtype), rangestat(RS_minmax), rangestat_arg(0), d (d)
1200
+ {
1201
+ switch (qtype) {
1202
+ case QT_8bit:
1203
+ case QT_8bit_uniform:
1204
+ case QT_8bit_direct:
1205
+ code_size = d;
1206
+ break;
1207
+ case QT_4bit:
1208
+ case QT_4bit_uniform:
1209
+ code_size = (d + 1) / 2;
1210
+ break;
1211
+ case QT_6bit:
1212
+ code_size = (d * 6 + 7) / 8;
1213
+ break;
1214
+ case QT_fp16:
1215
+ code_size = d * 2;
1216
+ break;
1217
+ }
1218
+
1219
+ }
1220
+
1221
+ ScalarQuantizer::ScalarQuantizer ():
1222
+ qtype(QT_8bit),
1223
+ rangestat(RS_minmax), rangestat_arg(0), d (0), code_size(0)
1224
+ {}
1225
+
1226
+ void ScalarQuantizer::train (size_t n, const float *x)
1227
+ {
1228
+ int bit_per_dim =
1229
+ qtype == QT_4bit_uniform ? 4 :
1230
+ qtype == QT_4bit ? 4 :
1231
+ qtype == QT_6bit ? 6 :
1232
+ qtype == QT_8bit_uniform ? 8 :
1233
+ qtype == QT_8bit ? 8 : -1;
1234
+
1235
+ switch (qtype) {
1236
+ case QT_4bit_uniform: case QT_8bit_uniform:
1237
+ train_Uniform (rangestat, rangestat_arg,
1238
+ n * d, 1 << bit_per_dim, x, trained);
1239
+ break;
1240
+ case QT_4bit: case QT_8bit: case QT_6bit:
1241
+ train_NonUniform (rangestat, rangestat_arg,
1242
+ n, d, 1 << bit_per_dim, x, trained);
1243
+ break;
1244
+ case QT_fp16:
1245
+ case QT_8bit_direct:
1246
+ // no training necessary
1247
+ break;
1248
+ }
1249
+ }
1250
+
1251
+ void ScalarQuantizer::train_residual(size_t n,
1252
+ const float *x,
1253
+ Index *quantizer,
1254
+ bool by_residual,
1255
+ bool verbose)
1256
+ {
1257
+ const float * x_in = x;
1258
+
1259
+ // 100k points more than enough
1260
+ x = fvecs_maybe_subsample (
1261
+ d, (size_t*)&n, 100000,
1262
+ x, verbose, 1234);
1263
+
1264
+ ScopeDeleter<float> del_x (x_in == x ? nullptr : x);
1265
+
1266
+ if (by_residual) {
1267
+ std::vector<Index::idx_t> idx(n);
1268
+ quantizer->assign (n, x, idx.data());
1269
+
1270
+ std::vector<float> residuals(n * d);
1271
+ quantizer->compute_residual_n (n, x, residuals.data(), idx.data());
1272
+
1273
+ train (n, residuals.data());
1274
+ } else {
1275
+ train (n, x);
1276
+ }
1277
+ }
1278
+
1279
+
1280
+ ScalarQuantizer::Quantizer *ScalarQuantizer::select_quantizer () const
1281
+ {
1282
+ #ifdef USE_F16C
1283
+ if (d % 8 == 0) {
1284
+ return select_quantizer_1<8> (qtype, d, trained);
1285
+ } else
1286
+ #endif
1287
+ {
1288
+ return select_quantizer_1<1> (qtype, d, trained);
1289
+ }
1290
+ }
1291
+
1292
+
1293
+ void ScalarQuantizer::compute_codes (const float * x,
1294
+ uint8_t * codes,
1295
+ size_t n) const
1296
+ {
1297
+ std::unique_ptr<Quantizer> squant(select_quantizer ());
1298
+
1299
+ memset (codes, 0, code_size * n);
1300
+ #pragma omp parallel for
1301
+ for (size_t i = 0; i < n; i++)
1302
+ squant->encode_vector (x + i * d, codes + i * code_size);
1303
+ }
1304
+
1305
+ void ScalarQuantizer::decode (const uint8_t *codes, float *x, size_t n) const
1306
+ {
1307
+ std::unique_ptr<Quantizer> squant(select_quantizer ());
1308
+
1309
+ #pragma omp parallel for
1310
+ for (size_t i = 0; i < n; i++)
1311
+ squant->decode_vector (codes + i * code_size, x + i * d);
1312
+ }
1313
+
1314
+
1315
+ SQDistanceComputer *
1316
+ ScalarQuantizer::get_distance_computer (MetricType metric) const
1317
+ {
1318
+ FAISS_THROW_IF_NOT(metric == METRIC_L2 || metric == METRIC_INNER_PRODUCT);
1319
+ #ifdef USE_F16C
1320
+ if (d % 8 == 0) {
1321
+ if (metric == METRIC_L2) {
1322
+ return select_distance_computer<SimilarityL2<8> >
1323
+ (qtype, d, trained);
1324
+ } else {
1325
+ return select_distance_computer<SimilarityIP<8> >
1326
+ (qtype, d, trained);
1327
+ }
1328
+ } else
1329
+ #endif
1330
+ {
1331
+ if (metric == METRIC_L2) {
1332
+ return select_distance_computer<SimilarityL2<1> >
1333
+ (qtype, d, trained);
1334
+ } else {
1335
+ return select_distance_computer<SimilarityIP<1> >
1336
+ (qtype, d, trained);
1337
+ }
1338
+ }
1339
+ }
1340
+
1341
+
1342
+ /*******************************************************************
1343
+ * IndexScalarQuantizer/IndexIVFScalarQuantizer scanner object
1344
+ *
1345
+ * It is an InvertedListScanner, but is designed to work with
1346
+ * IndexScalarQuantizer as well.
1347
+ ********************************************************************/
1348
+
1349
+ namespace {
1350
+
1351
+
1352
+ template<class DCClass>
1353
+ struct IVFSQScannerIP: InvertedListScanner {
1354
+ DCClass dc;
1355
+ bool store_pairs, by_residual;
1356
+
1357
+ size_t code_size;
1358
+
1359
+ idx_t list_no; /// current list (set to 0 for Flat index
1360
+ float accu0; /// added to all distances
1361
+
1362
+ IVFSQScannerIP(int d, const std::vector<float> & trained,
1363
+ size_t code_size, bool store_pairs,
1364
+ bool by_residual):
1365
+ dc(d, trained), store_pairs(store_pairs),
1366
+ by_residual(by_residual),
1367
+ code_size(code_size), list_no(0), accu0(0)
1368
+ {}
1369
+
1370
+
1371
+ void set_query (const float *query) override {
1372
+ dc.set_query (query);
1373
+ }
1374
+
1375
+ void set_list (idx_t list_no, float coarse_dis) override {
1376
+ this->list_no = list_no;
1377
+ accu0 = by_residual ? coarse_dis : 0;
1378
+ }
1379
+
1380
+ float distance_to_code (const uint8_t *code) const final {
1381
+ return accu0 + dc.query_to_code (code);
1382
+ }
1383
+
1384
+ size_t scan_codes (size_t list_size,
1385
+ const uint8_t *codes,
1386
+ const idx_t *ids,
1387
+ float *simi, idx_t *idxi,
1388
+ size_t k) const override
1389
+ {
1390
+ size_t nup = 0;
1391
+
1392
+ for (size_t j = 0; j < list_size; j++) {
1393
+
1394
+ float accu = accu0 + dc.query_to_code (codes);
1395
+
1396
+ if (accu > simi [0]) {
1397
+ minheap_pop (k, simi, idxi);
1398
+ int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1399
+ minheap_push (k, simi, idxi, accu, id);
1400
+ nup++;
1401
+ }
1402
+ codes += code_size;
1403
+ }
1404
+ return nup;
1405
+ }
1406
+
1407
+ void scan_codes_range (size_t list_size,
1408
+ const uint8_t *codes,
1409
+ const idx_t *ids,
1410
+ float radius,
1411
+ RangeQueryResult & res) const override
1412
+ {
1413
+ for (size_t j = 0; j < list_size; j++) {
1414
+ float accu = accu0 + dc.query_to_code (codes);
1415
+ if (accu > radius) {
1416
+ int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1417
+ res.add (accu, id);
1418
+ }
1419
+ codes += code_size;
1420
+ }
1421
+ }
1422
+
1423
+
1424
+ };
1425
+
1426
+
1427
+ template<class DCClass>
1428
+ struct IVFSQScannerL2: InvertedListScanner {
1429
+
1430
+ DCClass dc;
1431
+
1432
+ bool store_pairs, by_residual;
1433
+ size_t code_size;
1434
+ const Index *quantizer;
1435
+ idx_t list_no; /// current inverted list
1436
+ const float *x; /// current query
1437
+
1438
+ std::vector<float> tmp;
1439
+
1440
+ IVFSQScannerL2(int d, const std::vector<float> & trained,
1441
+ size_t code_size, const Index *quantizer,
1442
+ bool store_pairs, bool by_residual):
1443
+ dc(d, trained), store_pairs(store_pairs), by_residual(by_residual),
1444
+ code_size(code_size), quantizer(quantizer),
1445
+ list_no (0), x (nullptr), tmp (d)
1446
+ {
1447
+ }
1448
+
1449
+
1450
+ void set_query (const float *query) override {
1451
+ x = query;
1452
+ if (!quantizer) {
1453
+ dc.set_query (query);
1454
+ }
1455
+ }
1456
+
1457
+
1458
+ void set_list (idx_t list_no, float /*coarse_dis*/) override {
1459
+ if (by_residual) {
1460
+ this->list_no = list_no;
1461
+ // shift of x_in wrt centroid
1462
+ quantizer->compute_residual (x, tmp.data(), list_no);
1463
+ dc.set_query (tmp.data ());
1464
+ } else {
1465
+ dc.set_query (x);
1466
+ }
1467
+ }
1468
+
1469
+ float distance_to_code (const uint8_t *code) const final {
1470
+ return dc.query_to_code (code);
1471
+ }
1472
+
1473
+ size_t scan_codes (size_t list_size,
1474
+ const uint8_t *codes,
1475
+ const idx_t *ids,
1476
+ float *simi, idx_t *idxi,
1477
+ size_t k) const override
1478
+ {
1479
+ size_t nup = 0;
1480
+ for (size_t j = 0; j < list_size; j++) {
1481
+
1482
+ float dis = dc.query_to_code (codes);
1483
+
1484
+ if (dis < simi [0]) {
1485
+ maxheap_pop (k, simi, idxi);
1486
+ int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1487
+ maxheap_push (k, simi, idxi, dis, id);
1488
+ nup++;
1489
+ }
1490
+ codes += code_size;
1491
+ }
1492
+ return nup;
1493
+ }
1494
+
1495
+ void scan_codes_range (size_t list_size,
1496
+ const uint8_t *codes,
1497
+ const idx_t *ids,
1498
+ float radius,
1499
+ RangeQueryResult & res) const override
1500
+ {
1501
+ for (size_t j = 0; j < list_size; j++) {
1502
+ float dis = dc.query_to_code (codes);
1503
+ if (dis < radius) {
1504
+ int64_t id = store_pairs ? (list_no << 32 | j) : ids[j];
1505
+ res.add (dis, id);
1506
+ }
1507
+ codes += code_size;
1508
+ }
1509
+ }
1510
+
1511
+
1512
+ };
1513
+
1514
+ template<class DCClass>
1515
+ InvertedListScanner* sel2_InvertedListScanner
1516
+ (const ScalarQuantizer *sq,
1517
+ const Index *quantizer, bool store_pairs, bool r)
1518
+ {
1519
+ if (DCClass::Sim::metric_type == METRIC_L2) {
1520
+ return new IVFSQScannerL2<DCClass>(sq->d, sq->trained, sq->code_size,
1521
+ quantizer, store_pairs, r);
1522
+ } else if (DCClass::Sim::metric_type == METRIC_INNER_PRODUCT) {
1523
+ return new IVFSQScannerIP<DCClass>(sq->d, sq->trained, sq->code_size,
1524
+ store_pairs, r);
1525
+ } else {
1526
+ FAISS_THROW_MSG("unsupported metric type");
1527
+ }
1528
+ }
1529
+
1530
+ template<class Similarity, class Codec, bool uniform>
1531
+ InvertedListScanner* sel12_InvertedListScanner
1532
+ (const ScalarQuantizer *sq,
1533
+ const Index *quantizer, bool store_pairs, bool r)
1534
+ {
1535
+ constexpr int SIMDWIDTH = Similarity::simdwidth;
1536
+ using QuantizerClass = QuantizerTemplate<Codec, uniform, SIMDWIDTH>;
1537
+ using DCClass = DCTemplate<QuantizerClass, Similarity, SIMDWIDTH>;
1538
+ return sel2_InvertedListScanner<DCClass> (sq, quantizer, store_pairs, r);
1539
+ }
1540
+
1541
+
1542
+
1543
+ template<class Similarity>
1544
+ InvertedListScanner* sel1_InvertedListScanner
1545
+ (const ScalarQuantizer *sq, const Index *quantizer,
1546
+ bool store_pairs, bool r)
1547
+ {
1548
+ constexpr int SIMDWIDTH = Similarity::simdwidth;
1549
+ switch(sq->qtype) {
1550
+ case ScalarQuantizer::QT_8bit_uniform:
1551
+ return sel12_InvertedListScanner
1552
+ <Similarity, Codec8bit, true>(sq, quantizer, store_pairs, r);
1553
+ case ScalarQuantizer::QT_4bit_uniform:
1554
+ return sel12_InvertedListScanner
1555
+ <Similarity, Codec4bit, true>(sq, quantizer, store_pairs, r);
1556
+ case ScalarQuantizer::QT_8bit:
1557
+ return sel12_InvertedListScanner
1558
+ <Similarity, Codec8bit, false>(sq, quantizer, store_pairs, r);
1559
+ case ScalarQuantizer::QT_4bit:
1560
+ return sel12_InvertedListScanner
1561
+ <Similarity, Codec4bit, false>(sq, quantizer, store_pairs, r);
1562
+ case ScalarQuantizer::QT_6bit:
1563
+ return sel12_InvertedListScanner
1564
+ <Similarity, Codec6bit, false>(sq, quantizer, store_pairs, r);
1565
+ case ScalarQuantizer::QT_fp16:
1566
+ return sel2_InvertedListScanner
1567
+ <DCTemplate<QuantizerFP16<SIMDWIDTH>, Similarity, SIMDWIDTH> >
1568
+ (sq, quantizer, store_pairs, r);
1569
+ case ScalarQuantizer::QT_8bit_direct:
1570
+ if (sq->d % 16 == 0) {
1571
+ return sel2_InvertedListScanner
1572
+ <DistanceComputerByte<Similarity, SIMDWIDTH> >
1573
+ (sq, quantizer, store_pairs, r);
1574
+ } else {
1575
+ return sel2_InvertedListScanner
1576
+ <DCTemplate<Quantizer8bitDirect<SIMDWIDTH>,
1577
+ Similarity, SIMDWIDTH> >
1578
+ (sq, quantizer, store_pairs, r);
1579
+ }
1580
+
1581
+ }
1582
+
1583
+ FAISS_THROW_MSG ("unknown qtype");
1584
+ return nullptr;
1585
+ }
1586
+
1587
+ template<int SIMDWIDTH>
1588
+ InvertedListScanner* sel0_InvertedListScanner
1589
+ (MetricType mt, const ScalarQuantizer *sq,
1590
+ const Index *quantizer, bool store_pairs, bool by_residual)
1591
+ {
1592
+ if (mt == METRIC_L2) {
1593
+ return sel1_InvertedListScanner<SimilarityL2<SIMDWIDTH> >
1594
+ (sq, quantizer, store_pairs, by_residual);
1595
+ } else if (mt == METRIC_INNER_PRODUCT) {
1596
+ return sel1_InvertedListScanner<SimilarityIP<SIMDWIDTH> >
1597
+ (sq, quantizer, store_pairs, by_residual);
1598
+ } else {
1599
+ FAISS_THROW_MSG("unsupported metric type");
1600
+ }
1601
+ }
1602
+
1603
+
1604
+
1605
+ } // anonymous namespace
1606
+
1607
+
1608
+ InvertedListScanner* ScalarQuantizer::select_InvertedListScanner
1609
+ (MetricType mt, const Index *quantizer,
1610
+ bool store_pairs, bool by_residual) const
1611
+ {
1612
+ #ifdef USE_F16C
1613
+ if (d % 8 == 0) {
1614
+ return sel0_InvertedListScanner<8>
1615
+ (mt, this, quantizer, store_pairs, by_residual);
1616
+ } else
1617
+ #endif
1618
+ {
1619
+ return sel0_InvertedListScanner<1>
1620
+ (mt, this, quantizer, store_pairs, by_residual);
1621
+ }
1622
+ }
1623
+
1624
+
1625
+
1626
+
1627
+
1628
+ } // namespace faiss