faiss 0.2.3 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +4 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/Clustering.cpp +32 -0
  12. data/vendor/faiss/faiss/Clustering.h +14 -0
  13. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  14. data/vendor/faiss/faiss/IVFlib.h +26 -2
  15. data/vendor/faiss/faiss/Index.cpp +36 -3
  16. data/vendor/faiss/faiss/Index.h +43 -6
  17. data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
  18. data/vendor/faiss/faiss/Index2Layer.h +8 -17
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
  21. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  22. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  23. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  24. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  25. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  26. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  27. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  28. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  30. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  31. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  32. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  33. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  34. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  35. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  36. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  37. data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
  38. data/vendor/faiss/faiss/IndexFlat.h +16 -19
  39. data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
  40. data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
  41. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  42. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  43. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  44. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  45. data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
  46. data/vendor/faiss/faiss/IndexIVF.h +59 -22
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
  49. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  50. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  51. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  52. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  53. data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
  54. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  55. data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
  56. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
  57. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  58. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  59. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  60. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  61. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
  62. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
  63. data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
  64. data/vendor/faiss/faiss/IndexLSH.h +4 -16
  65. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  66. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  67. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
  68. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  69. data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
  70. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  71. data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
  72. data/vendor/faiss/faiss/IndexPQ.h +21 -22
  73. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  74. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  75. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  76. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  77. data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
  78. data/vendor/faiss/faiss/IndexRefine.h +14 -2
  79. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  80. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  81. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  82. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  83. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
  84. data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
  85. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  86. data/vendor/faiss/faiss/IndexShards.h +2 -1
  87. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  88. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  89. data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
  90. data/vendor/faiss/faiss/VectorTransform.h +25 -4
  91. data/vendor/faiss/faiss/clone_index.cpp +26 -3
  92. data/vendor/faiss/faiss/clone_index.h +3 -0
  93. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  94. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  95. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  102. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  103. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  104. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  105. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
  106. data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
  107. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  108. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  109. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  111. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  112. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  113. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  114. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  115. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  116. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  117. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  118. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  119. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  120. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  121. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  122. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  123. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  124. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
  125. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
  126. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  127. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  128. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  129. data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
  130. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  131. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  132. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  133. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
  134. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
  135. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  136. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  137. data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
  138. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  139. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  140. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  141. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  142. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  143. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  144. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
  145. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
  146. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  147. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
  148. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  149. data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
  150. data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
  151. data/vendor/faiss/faiss/impl/io.cpp +1 -1
  152. data/vendor/faiss/faiss/impl/io_macros.h +20 -0
  153. data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
  154. data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
  155. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  159. data/vendor/faiss/faiss/index_factory.cpp +772 -412
  160. data/vendor/faiss/faiss/index_factory.h +3 -0
  161. data/vendor/faiss/faiss/index_io.h +5 -0
  162. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  163. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  164. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  165. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  166. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  167. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  168. data/vendor/faiss/faiss/utils/distances.cpp +384 -58
  169. data/vendor/faiss/faiss/utils/distances.h +149 -18
  170. data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
  171. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  172. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  173. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  174. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  175. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  176. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  177. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  178. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  179. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  180. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  181. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  182. data/vendor/faiss/faiss/utils/random.h +5 -0
  183. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  184. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  185. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  186. data/vendor/faiss/faiss/utils/utils.h +1 -1
  187. metadata +46 -5
  188. data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
  189. data/vendor/faiss/faiss/IndexResidual.h +0 -152
@@ -0,0 +1,610 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ // quiet the noise
9
+ // clang-format off
10
+
11
+ #include <faiss/IndexAdditiveQuantizer.h>
12
+
13
+ #include <algorithm>
14
+ #include <cmath>
15
+ #include <cstring>
16
+
17
+ #include <faiss/impl/FaissAssert.h>
18
+ #include <faiss/impl/ResidualQuantizer.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/utils/distances.h>
21
+ #include <faiss/utils/extra_distances.h>
22
+ #include <faiss/utils/utils.h>
23
+
24
+
25
+ namespace faiss {
26
+
27
+ /**************************************************************************************
28
+ * IndexAdditiveQuantizer
29
+ **************************************************************************************/
30
+
31
+ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
34
+ MetricType metric):
35
+ IndexFlatCodes(aq->code_size, d, metric), aq(aq)
36
+ {
37
+ FAISS_THROW_IF_NOT(metric == METRIC_INNER_PRODUCT || metric == METRIC_L2);
38
+ }
39
+
40
+
41
+ namespace {
42
+
43
+ /************************************************************
44
+ * DistanceComputer implementation
45
+ ************************************************************/
46
+
47
+ template <class VectorDistance>
48
+ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
49
+ std::vector<float> tmp;
50
+ const AdditiveQuantizer & aq;
51
+ VectorDistance vd;
52
+ size_t d;
53
+
54
+ AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
55
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
56
+ tmp(iaq.d * 2),
57
+ aq(*iaq.aq),
58
+ vd(vd),
59
+ d(iaq.d)
60
+ {}
61
+
62
+ const float *q;
63
+ void set_query(const float* x) final {
64
+ q = x;
65
+ }
66
+
67
+ float symmetric_dis(idx_t i, idx_t j) final {
68
+ aq.decode(codes + i * d, tmp.data(), 1);
69
+ aq.decode(codes + j * d, tmp.data() + d, 1);
70
+ return vd(tmp.data(), tmp.data() + d);
71
+ }
72
+
73
+ float distance_to_code(const uint8_t *code) final {
74
+ aq.decode(code, tmp.data(), 1);
75
+ return vd(q, tmp.data());
76
+ }
77
+
78
+ virtual ~AQDistanceComputerDecompress() {}
79
+ };
80
+
81
+
82
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st>
83
+ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
84
+ std::vector<float> LUT;
85
+ const AdditiveQuantizer & aq;
86
+ size_t d;
87
+
88
+ explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
89
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
90
+ LUT(iaq.aq->total_codebook_size + iaq.d * 2),
91
+ aq(*iaq.aq),
92
+ d(iaq.d)
93
+ {}
94
+
95
+ float bias;
96
+ void set_query(const float* x) final {
97
+ // this is quite sub-optimal for multiple queries
98
+ aq.compute_LUT(1, x, LUT.data());
99
+ if (is_IP) {
100
+ bias = 0;
101
+ } else {
102
+ bias = fvec_norm_L2sqr(x, d);
103
+ }
104
+ }
105
+
106
+ float symmetric_dis(idx_t i, idx_t j) final {
107
+ float *tmp = LUT.data();
108
+ aq.decode(codes + i * d, tmp, 1);
109
+ aq.decode(codes + j * d, tmp + d, 1);
110
+ return fvec_L2sqr(tmp, tmp + d, d);
111
+ }
112
+
113
+ float distance_to_code(const uint8_t *code) final {
114
+ return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
115
+ }
116
+
117
+ virtual ~AQDistanceComputerLUT() {}
118
+ };
119
+
120
+
121
+
122
+ /************************************************************
123
+ * scanning implementation for search
124
+ ************************************************************/
125
+
126
+
127
+ template <class VectorDistance, class ResultHandler>
128
+ void search_with_decompress(
129
+ const IndexAdditiveQuantizer& ir,
130
+ const float* xq,
131
+ VectorDistance& vd,
132
+ ResultHandler& res) {
133
+ const uint8_t* codes = ir.codes.data();
134
+ size_t ntotal = ir.ntotal;
135
+ size_t code_size = ir.code_size;
136
+ const AdditiveQuantizer *aq = ir.aq;
137
+
138
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
139
+
140
+ #pragma omp parallel for if(res.nq > 100)
141
+ for (int64_t q = 0; q < res.nq; q++) {
142
+ SingleResultHandler resi(res);
143
+ resi.begin(q);
144
+ std::vector<float> tmp(ir.d);
145
+ const float* x = xq + ir.d * q;
146
+ for (size_t i = 0; i < ntotal; i++) {
147
+ aq->decode(codes + i * code_size, tmp.data(), 1);
148
+ float dis = vd(x, tmp.data());
149
+ resi.add_result(dis, i);
150
+ }
151
+ resi.end();
152
+ }
153
+ }
154
+
155
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st, class ResultHandler>
156
+ void search_with_LUT(
157
+ const IndexAdditiveQuantizer& ir,
158
+ const float* xq,
159
+ ResultHandler& res)
160
+ {
161
+ const AdditiveQuantizer & aq = *ir.aq;
162
+ const uint8_t* codes = ir.codes.data();
163
+ size_t ntotal = ir.ntotal;
164
+ size_t code_size = aq.code_size;
165
+ size_t nq = res.nq;
166
+ size_t d = ir.d;
167
+
168
+ using SingleResultHandler = typename ResultHandler::SingleResultHandler;
169
+ std::unique_ptr<float []> LUT(new float[nq * aq.total_codebook_size]);
170
+
171
+ aq.compute_LUT(nq, xq, LUT.get());
172
+
173
+ #pragma omp parallel for if(nq > 100)
174
+ for (int64_t q = 0; q < nq; q++) {
175
+ SingleResultHandler resi(res);
176
+ resi.begin(q);
177
+ std::vector<float> tmp(aq.d);
178
+ const float *LUT_q = LUT.get() + aq.total_codebook_size * q;
179
+ float bias = 0;
180
+ if (!is_IP) { // the LUT function returns ||y||^2 - 2 * <x, y>, need to add ||x||^2
181
+ bias = fvec_norm_L2sqr(xq + q * d, d);
182
+ }
183
+ for (size_t i = 0; i < ntotal; i++) {
184
+ float dis = aq.compute_1_distance_LUT<is_IP, st>(
185
+ codes + i * code_size,
186
+ LUT_q
187
+ );
188
+ resi.add_result(dis + bias, i);
189
+ }
190
+ resi.end();
191
+ }
192
+
193
+ }
194
+
195
+
196
+ } // anonymous namespace
197
+
198
+
199
+ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
200
+
201
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
202
+ if (metric_type == METRIC_L2) {
203
+ using VD = VectorDistance<METRIC_L2>;
204
+ VD vd = {size_t(d), metric_arg};
205
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
206
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
207
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
208
+ VD vd = {size_t(d), metric_arg};
209
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
210
+ } else {
211
+ FAISS_THROW_MSG("unsupported metric");
212
+ }
213
+ } else {
214
+ if (metric_type == METRIC_INNER_PRODUCT) {
215
+ return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
216
+ } else {
217
+ switch(aq->search_type) {
218
+ #define DISPATCH(st) \
219
+ case AdditiveQuantizer::st: \
220
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
221
+ break;
222
+ DISPATCH(ST_norm_float)
223
+ DISPATCH(ST_LUT_nonorm)
224
+ DISPATCH(ST_norm_qint8)
225
+ DISPATCH(ST_norm_qint4)
226
+ DISPATCH(ST_norm_cqint4)
227
+ case AdditiveQuantizer::ST_norm_cqint8:
228
+ case AdditiveQuantizer::ST_norm_lsq2x4:
229
+ case AdditiveQuantizer::ST_norm_rq2x4:
230
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
231
+ break;
232
+ #undef DISPATCH
233
+ default:
234
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
235
+ }
236
+ }
237
+ }
238
+ }
239
+
240
+
241
+
242
+
243
+ void IndexAdditiveQuantizer::search(
244
+ idx_t n,
245
+ const float* x,
246
+ idx_t k,
247
+ float* distances,
248
+ idx_t* labels,
249
+ const SearchParameters* params) const {
250
+
251
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
252
+
253
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
254
+ if (metric_type == METRIC_L2) {
255
+ using VD = VectorDistance<METRIC_L2>;
256
+ VD vd = {size_t(d), metric_arg};
257
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
258
+ search_with_decompress(*this, x, vd, rh);
259
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
260
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
261
+ VD vd = {size_t(d), metric_arg};
262
+ HeapResultHandler<VD::C> rh(n, distances, labels, k);
263
+ search_with_decompress(*this, x, vd, rh);
264
+ }
265
+ } else {
266
+ if (metric_type == METRIC_INNER_PRODUCT) {
267
+ HeapResultHandler<CMin<float, idx_t> > rh(n, distances, labels, k);
268
+ search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
269
+ } else {
270
+ HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
271
+ switch(aq->search_type) {
272
+ #define DISPATCH(st) \
273
+ case AdditiveQuantizer::st: \
274
+ search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
275
+ break;
276
+ DISPATCH(ST_norm_float)
277
+ DISPATCH(ST_LUT_nonorm)
278
+ DISPATCH(ST_norm_qint8)
279
+ DISPATCH(ST_norm_qint4)
280
+ DISPATCH(ST_norm_cqint4)
281
+ case AdditiveQuantizer::ST_norm_cqint8:
282
+ case AdditiveQuantizer::ST_norm_lsq2x4:
283
+ case AdditiveQuantizer::ST_norm_rq2x4:
284
+ search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
285
+ break;
286
+ #undef DISPATCH
287
+ default:
288
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
289
+ }
290
+ }
291
+
292
+ }
293
+ }
294
+
295
+ void IndexAdditiveQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
296
+ return aq->compute_codes(x, bytes, n);
297
+ }
298
+
299
+ void IndexAdditiveQuantizer::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
300
+ return aq->decode(bytes, x, n);
301
+ }
302
+
303
+
304
+
305
+
306
+ /**************************************************************************************
307
+ * IndexResidualQuantizer
308
+ **************************************************************************************/
309
+
310
+ IndexResidualQuantizer::IndexResidualQuantizer(
311
+ int d, ///< dimensionality of the input vectors
312
+ size_t M, ///< number of subquantizers
313
+ size_t nbits, ///< number of bit per subvector index
314
+ MetricType metric,
315
+ Search_type_t search_type)
316
+ : IndexResidualQuantizer(d, std::vector<size_t>(M, nbits), metric, search_type) {
317
+ }
318
+
319
+ IndexResidualQuantizer::IndexResidualQuantizer(
320
+ int d,
321
+ const std::vector<size_t>& nbits,
322
+ MetricType metric,
323
+ Search_type_t search_type)
324
+ : IndexAdditiveQuantizer(d, &rq, metric), rq(d, nbits, search_type) {
325
+ code_size = rq.code_size;
326
+ is_trained = false;
327
+ }
328
+
329
+ IndexResidualQuantizer::IndexResidualQuantizer() : IndexResidualQuantizer(0, 0, 0) {}
330
+
331
+ void IndexResidualQuantizer::train(idx_t n, const float* x) {
332
+ rq.train(n, x);
333
+ is_trained = true;
334
+ }
335
+
336
+
337
+ /**************************************************************************************
338
+ * IndexLocalSearchQuantizer
339
+ **************************************************************************************/
340
+
341
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer(
342
+ int d,
343
+ size_t M, ///< number of subquantizers
344
+ size_t nbits, ///< number of bit per subvector index
345
+ MetricType metric,
346
+ Search_type_t search_type)
347
+ : IndexAdditiveQuantizer(d, &lsq, metric), lsq(d, M, nbits, search_type) {
348
+ code_size = lsq.code_size;
349
+ is_trained = false;
350
+ }
351
+
352
+ IndexLocalSearchQuantizer::IndexLocalSearchQuantizer() : IndexLocalSearchQuantizer(0, 0, 0) {}
353
+
354
+ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
355
+ lsq.train(n, x);
356
+ is_trained = true;
357
+ }
358
+
359
+
360
+ /**************************************************************************************
361
+ * IndexProductResidualQuantizer
362
+ **************************************************************************************/
363
+
364
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer(
365
+ int d, ///< dimensionality of the input vectors
366
+ size_t nsplits, ///< number of residual quantizers
367
+ size_t Msub, ///< number of subquantizers per RQ
368
+ size_t nbits, ///< number of bit per subvector index
369
+ MetricType metric,
370
+ Search_type_t search_type)
371
+ : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
372
+ code_size = prq.code_size;
373
+ is_trained = false;
374
+ }
375
+
376
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer()
377
+ : IndexProductResidualQuantizer(0, 0, 0, 0) {}
378
+
379
+ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
380
+ prq.train(n, x);
381
+ is_trained = true;
382
+ }
383
+
384
+
385
+ /**************************************************************************************
386
+ * IndexProductLocalSearchQuantizer
387
+ **************************************************************************************/
388
+
389
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t nsplits, ///< number of local search quantizers
392
+ size_t Msub, ///< number of subquantizers per LSQ
393
+ size_t nbits, ///< number of bit per subvector index
394
+ MetricType metric,
395
+ Search_type_t search_type)
396
+ : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
397
+ code_size = plsq.code_size;
398
+ is_trained = false;
399
+ }
400
+
401
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
402
+ : IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
403
+
404
+ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
405
+ plsq.train(n, x);
406
+ is_trained = true;
407
+ }
408
+
409
+
410
+ /**************************************************************************************
411
+ * AdditiveCoarseQuantizer
412
+ **************************************************************************************/
413
+
414
+ AdditiveCoarseQuantizer::AdditiveCoarseQuantizer(
415
+ idx_t d,
416
+ AdditiveQuantizer* aq,
417
+ MetricType metric):
418
+ Index(d, metric), aq(aq)
419
+ {}
420
+
421
+ void AdditiveCoarseQuantizer::add(idx_t, const float*) {
422
+ FAISS_THROW_MSG("not applicable");
423
+ }
424
+
425
+ void AdditiveCoarseQuantizer::reconstruct(idx_t key, float* recons) const {
426
+ aq->decode_64bit(key, recons);
427
+ }
428
+
429
+ void AdditiveCoarseQuantizer::reset() {
430
+ FAISS_THROW_MSG("not applicable");
431
+ }
432
+
433
+
434
+ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
435
+ if (verbose) {
436
+ printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
437
+ }
438
+ size_t norms_size = sizeof(float) << aq->tot_bits;
439
+
440
+ FAISS_THROW_IF_NOT_MSG (
441
+ norms_size <= aq->max_mem_distances,
442
+ "the RCQ norms matrix will become too large, please reduce the number of quantization steps"
443
+ );
444
+
445
+ aq->train(n, x);
446
+ is_trained = true;
447
+ ntotal = (idx_t)1 << aq->tot_bits;
448
+
449
+ if (metric_type == METRIC_L2) {
450
+ if (verbose) {
451
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
452
+ }
453
+ // this is not necessary for the residualcoarsequantizer when
454
+ // using beam search. We'll see if the memory overhead is too high
455
+ centroid_norms.resize(ntotal);
456
+ aq->compute_centroid_norms(centroid_norms.data());
457
+ }
458
+ }
459
+
460
+ void AdditiveCoarseQuantizer::search(
461
+ idx_t n,
462
+ const float* x,
463
+ idx_t k,
464
+ float* distances,
465
+ idx_t* labels,
466
+ const SearchParameters * params) const {
467
+
468
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
469
+
470
+ if (metric_type == METRIC_INNER_PRODUCT) {
471
+ aq->knn_centroids_inner_product(n, x, k, distances, labels);
472
+ } else if (metric_type == METRIC_L2) {
473
+ FAISS_THROW_IF_NOT(centroid_norms.size() == ntotal);
474
+ aq->knn_centroids_L2(
475
+ n, x, k, distances, labels, centroid_norms.data());
476
+ }
477
+ }
478
+
479
+ /**************************************************************************************
480
+ * ResidualCoarseQuantizer
481
+ **************************************************************************************/
482
+
483
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
484
+ int d, ///< dimensionality of the input vectors
485
+ const std::vector<size_t>& nbits,
486
+ MetricType metric)
487
+ : AdditiveCoarseQuantizer(d, &rq, metric), rq(d, nbits), beam_factor(4.0) {
488
+ FAISS_THROW_IF_NOT(rq.tot_bits <= 63);
489
+ is_trained = false;
490
+ }
491
+
492
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(
493
+ int d,
494
+ size_t M, ///< number of subquantizers
495
+ size_t nbits, ///< number of bit per subvector index
496
+ MetricType metric)
497
+ : ResidualCoarseQuantizer(d, std::vector<size_t>(M, nbits), metric) {}
498
+
499
+ ResidualCoarseQuantizer::ResidualCoarseQuantizer(): ResidualCoarseQuantizer(0, 0, 0) {}
500
+
501
+
502
+
503
+ void ResidualCoarseQuantizer::set_beam_factor(float new_beam_factor) {
504
+ beam_factor = new_beam_factor;
505
+ if (new_beam_factor > 0) {
506
+ FAISS_THROW_IF_NOT(new_beam_factor >= 1.0);
507
+ return;
508
+ } else if (metric_type == METRIC_L2 && ntotal != centroid_norms.size()) {
509
+ if (verbose) {
510
+ printf("AdditiveCoarseQuantizer::train: computing centroid norms for %zd centroids\n", size_t(ntotal));
511
+ }
512
+ centroid_norms.resize(ntotal);
513
+ aq->compute_centroid_norms(centroid_norms.data());
514
+ }
515
+ }
516
+
517
+ void ResidualCoarseQuantizer::search(
518
+ idx_t n,
519
+ const float* x,
520
+ idx_t k,
521
+ float* distances,
522
+ idx_t* labels,
523
+ const SearchParameters * params
524
+ ) const {
525
+
526
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
527
+
528
+ if (beam_factor < 0) {
529
+ AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
530
+ return;
531
+ }
532
+
533
+ int beam_size = int(k * beam_factor);
534
+ if (beam_size > ntotal) {
535
+ beam_size = ntotal;
536
+ }
537
+ size_t memory_per_point = rq.memory_per_point(beam_size);
538
+
539
+ /*
540
+
541
+ printf("mem per point %ld n=%d max_mem_distance=%ld mem_kb=%zd\n",
542
+ memory_per_point, int(n), rq.max_mem_distances, get_mem_usage_kb());
543
+ */
544
+ if (n > 1 && memory_per_point * n > rq.max_mem_distances) {
545
+ // then split queries to reduce temp memory
546
+ idx_t bs = rq.max_mem_distances / memory_per_point;
547
+ if (bs == 0) {
548
+ bs = 1; // otherwise we can't do much
549
+ }
550
+ if (verbose) {
551
+ printf("ResidualCoarseQuantizer::search: run %d searches in batches of size %d\n",
552
+ int(n),
553
+ int(bs));
554
+ }
555
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
556
+ idx_t i1 = std::min(n, i0 + bs);
557
+ search(i1 - i0, x + i0 * d, k, distances + i0 * k, labels + i0 * k);
558
+ InterruptCallback::check();
559
+ }
560
+ return;
561
+ }
562
+
563
+ std::vector<int32_t> codes(beam_size * rq.M * n);
564
+ std::vector<float> beam_distances(n * beam_size);
565
+
566
+ rq.refine_beam(
567
+ n, 1, x, beam_size, codes.data(), nullptr, beam_distances.data());
568
+
569
+ #pragma omp parallel for if (n > 4000)
570
+ for (idx_t i = 0; i < n; i++) {
571
+ memcpy(distances + i * k,
572
+ beam_distances.data() + beam_size * i,
573
+ k * sizeof(distances[0]));
574
+
575
+ const int32_t* codes_i = codes.data() + beam_size * i * rq.M;
576
+ for (idx_t j = 0; j < k; j++) {
577
+ idx_t l = 0;
578
+ int shift = 0;
579
+ for (int m = 0; m < rq.M; m++) {
580
+ l |= (*codes_i++) << shift;
581
+ shift += rq.nbits[m];
582
+ }
583
+ labels[i * k + j] = l;
584
+ }
585
+ }
586
+ }
587
+
588
+ /**************************************************************************************
589
+ * LocalSearchCoarseQuantizer
590
+ **************************************************************************************/
591
+
592
+ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer(
593
+ int d, ///< dimensionality of the input vectors
594
+ size_t M, ///< number of subquantizers
595
+ size_t nbits, ///< number of bit per subvector index
596
+ MetricType metric)
597
+ : AdditiveCoarseQuantizer(d, &lsq, metric), lsq(d, M, nbits) {
598
+ FAISS_THROW_IF_NOT(lsq.tot_bits <= 63);
599
+ is_trained = false;
600
+ }
601
+
602
+
603
+ LocalSearchCoarseQuantizer::LocalSearchCoarseQuantizer() {
604
+ aq = &lsq;
605
+ }
606
+
607
+
608
+
609
+
610
+ } // namespace faiss