faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,590 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexIVFAdditiveQuantizerFastScan.h>
9
+
10
+ #include <cassert>
11
+ #include <cinttypes>
12
+ #include <cstdio>
13
+
14
+ #include <omp.h>
15
+
16
+ #include <memory>
17
+
18
+ #include <faiss/impl/AuxIndexStructures.h>
19
+ #include <faiss/impl/FaissAssert.h>
20
+ #include <faiss/impl/LookupTableScaler.h>
21
+ #include <faiss/impl/pq4_fast_scan.h>
22
+ #include <faiss/invlists/BlockInvertedLists.h>
23
+ #include <faiss/utils/distances.h>
24
+ #include <faiss/utils/hamming.h>
25
+ #include <faiss/utils/quantize_lut.h>
26
+ #include <faiss/utils/simdlib.h>
27
+ #include <faiss/utils/utils.h>
28
+
29
+ namespace faiss {
30
+
31
/// Round `value` up to the next multiple of `multiple` (returns `value`
/// unchanged when it is already a multiple). `multiple` must be non-zero.
inline size_t roundup(size_t value, size_t multiple) {
    const size_t blocks = (value + multiple - 1) / multiple;
    return blocks * multiple;
}
34
+
35
/// Construct an empty fast-scan IVF index backed by an additive quantizer.
/// @param quantizer coarse quantizer that assigns vectors to inverted lists
/// @param aq        additive quantizer producing the 4-bit sub-codes; may be
///                  nullptr, in which case init() must be called later
/// @param d         dimensionality of the indexed vectors
/// @param nlist     number of inverted lists
/// @param metric    distance metric (L2 or inner product)
/// @param bbs       fast-scan block size, forwarded to init_fastscan()
IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan(
        Index* quantizer,
        AdditiveQuantizer* aq,
        size_t d,
        size_t nlist,
        MetricType metric,
        int bbs)
        : IndexIVFFastScan(quantizer, d, nlist, 0, metric) {
    // code_size is passed as 0 to the base; init() establishes the real
    // code layout once the additive quantizer is known
    if (aq != nullptr) {
        init(aq, nlist, metric, bbs);
    }
}
47
+
48
/// Finish construction once the additive quantizer is known: validate its
/// configuration, set up the fast-scan code layout and the training budget.
void IndexIVFAdditiveQuantizerFastScan::init(
        AdditiveQuantizer* aq,
        size_t nlist,
        MetricType metric,
        int bbs) {
    FAISS_THROW_IF_NOT(aq != nullptr);
    FAISS_THROW_IF_NOT(!aq->nbits.empty());
    // fast-scan kernels only support 4-bit sub-quantizers
    FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
    if (metric == METRIC_INNER_PRODUCT) {
        FAISS_THROW_IF_NOT_MSG(
                aq->search_type == AdditiveQuantizer::ST_LUT_nonorm,
                "Search type must be ST_LUT_nonorm for IP metric");
    } else {
        // for L2 the vector norm must itself be encoded with a 2x4-bit
        // quantizer so it can go through the same SIMD LUT machinery
        FAISS_THROW_IF_NOT_MSG(
                aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
                        aq->search_type == AdditiveQuantizer::ST_norm_rq2x4,
                "Search type must be lsq2x4 or rq2x4 for L2 metric");
    }

    this->aq = aq;
    if (metric_type == METRIC_L2) {
        M = aq->M + 2; // 2 extra 4-bit sub-codes carry the encoded norm
    } else {
        M = aq->M;
    }
    init_fastscan(M, 4, nlist, metric, bbs);

    // cap on the number of points used by train_residual()
    max_train_points = 1024 * ksub * M;
    by_residual = true;
}
78
+
79
/// Build a fast-scan index from an existing (non fast-scan)
/// IndexIVFAdditiveQuantizer: copies the trained state and re-packs every
/// inverted list into the bbs-aligned fast-scan layout.
IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan(
        const IndexIVFAdditiveQuantizer& orig,
        int bbs)
        : IndexIVFFastScan(
                  orig.quantizer,
                  orig.d,
                  orig.nlist,
                  0,
                  orig.metric_type),
          aq(orig.aq) {
    // only IP indexes, or indexes not encoded by residual, can be converted
    FAISS_THROW_IF_NOT(
            metric_type == METRIC_INNER_PRODUCT || !orig.by_residual);

    init(aq, nlist, metric_type, bbs);

    is_trained = orig.is_trained;
    ntotal = orig.ntotal;
    nprobe = orig.nprobe;

    // re-pack each inverted list: pad to a multiple of bbs entries and
    // interleave the codes with pq4_pack_codes for SIMD scanning
    for (size_t i = 0; i < nlist; i++) {
        size_t nb = orig.invlists->list_size(i);
        size_t nb2 = roundup(nb, bbs);
        AlignedTable<uint8_t> tmp(nb2 * M2 / 2); // 2 codes per byte (4 bits)
        pq4_pack_codes(
                InvertedLists::ScopedCodes(orig.invlists, i).get(),
                nb,
                M,
                nb2,
                bbs,
                M2,
                tmp.get());
        invlists->add_entries(
                i,
                nb,
                InvertedLists::ScopedIds(orig.invlists, i).get(),
                tmp.get());
    }

    // keep a pointer to the original (flat) inverted lists
    orig_invlists = orig.invlists;
}
119
+
120
+ IndexIVFAdditiveQuantizerFastScan::IndexIVFAdditiveQuantizerFastScan() {
121
+ bbs = 0;
122
+ M2 = 0;
123
+ aq = nullptr;
124
+
125
+ is_trained = false;
126
+ }
127
+
128
+ IndexIVFAdditiveQuantizerFastScan::~IndexIVFAdditiveQuantizerFastScan() {}
129
+
130
+ /*********************************************************
131
+ * Training
132
+ *********************************************************/
133
+
134
/// Train the additive quantizer (and, for L2, its norm tables and the norm
/// scale) on a subsample of at most max_train_points vectors.
/// @param n     number of training vectors
/// @param x_in  training vectors, size n * d
void IndexIVFAdditiveQuantizerFastScan::train_residual(
        idx_t n,
        const float* x_in) {
    if (aq->is_trained) {
        return;
    }

    // subsample the training set to bound training cost
    const int seed = 0x12345;
    size_t nt = n;
    const float* x = fvecs_maybe_subsample(
            d, &nt, max_train_points, x_in, verbose, seed);
    n = nt;
    if (verbose) {
        printf("training additive quantizer on %zd vectors\n", nt);
    }
    aq->verbose = verbose;

    // take ownership of the subsampled copy, if one was made
    std::unique_ptr<float[]> del_x;
    if (x != x_in) {
        del_x.reset((float*)x);
    }

    const float* trainset;
    std::vector<float> residuals(n * d);
    std::vector<idx_t> assign(n);

    if (by_residual) {
        // train on residuals w.r.t. the assigned coarse centroids
        if (verbose) {
            printf("computing residuals\n");
        }
        quantizer->assign(n, x, assign.data());
        residuals.resize(n * d);
        for (idx_t i = 0; i < n; i++) {
            quantizer->compute_residual(
                    x + i * d, residuals.data() + i * d, assign[i]);
        }
        trainset = residuals.data();
    } else {
        trainset = x;
    }

    if (verbose) {
        // NOTE(review): %zd expects size_t — confirm that ksub and d have
        // matching types for this format string
        printf("training %zdx%zd additive quantizer on "
               "%" PRId64 " vectors in %dD\n",
               aq->M,
               ksub,
               n,
               d);
    }
    aq->verbose = verbose;
    aq->train(n, trainset);

    // train norm quantizer
    if (by_residual && metric_type == METRIC_L2) {
        // the norm tables must encode the norm of the *reconstructed* full
        // vector (decoded residual + coarse centroid), so round-trip the
        // residuals through the freshly trained quantizer first
        std::vector<float> decoded_x(n * d);
        std::vector<uint8_t> x_codes(n * aq->code_size);
        aq->compute_codes(residuals.data(), x_codes.data(), n);
        aq->decode(x_codes.data(), decoded_x.data(), n);

        // add coarse centroids
        FAISS_THROW_IF_NOT(assign.size() == n);
        std::vector<float> centroid(d);
        for (idx_t i = 0; i < n; i++) {
            auto xi = decoded_x.data() + i * d;
            quantizer->reconstruct(assign[i], centroid.data());
            fvec_add(d, centroid.data(), xi, xi);
        }

        std::vector<float> norms(n, 0);
        fvec_norms_L2sqr(norms.data(), decoded_x.data(), d, n);

        // re-train norm tables
        aq->train_norm(n, norms.data());
    }

    if (metric_type == METRIC_L2) {
        estimate_norm_scale(n, x);
    }
}
213
+
214
/// Estimate the integer factor by which the norm look-up tables should be
/// divided before quantization (L2 only); the average per-query estimate
/// from quantize_lut::aq_estimate_norm_scale is rounded, clamped to >= 1,
/// and stored in norm_scale.
void IndexIVFAdditiveQuantizerFastScan::estimate_norm_scale(
        idx_t n,
        const float* x_in) {
    FAISS_THROW_IF_NOT(metric_type == METRIC_L2);

    // subsample the queries used for the estimation
    constexpr int seed = 0x980903;
    constexpr size_t max_points_estimated = 65536;
    size_t ns = n;
    const float* x = fvecs_maybe_subsample(
            d, &ns, max_points_estimated, x_in, verbose, seed);
    n = ns;
    std::unique_ptr<float[]> del_x;
    if (x != x_in) {
        del_x.reset((float*)x);
    }

    std::vector<idx_t> coarse_ids(n);
    std::vector<float> coarse_dis(n);
    quantizer->search(n, x, 1, coarse_dis.data(), coarse_ids.data());

    AlignedTable<float> dis_tables;
    AlignedTable<float> biases;

    // temporarily force nprobe = 1 so compute_LUT produces exactly one
    // bias per query; restore the user setting afterwards
    size_t index_nprobe = nprobe;
    nprobe = 1;
    compute_LUT(n, x, coarse_ids.data(), coarse_dis.data(), dis_tables, biases);
    nprobe = index_nprobe;

    float scale = 0;

#pragma omp parallel for reduction(+ : scale)
    for (idx_t i = 0; i < n; i++) {
        const float* lut = dis_tables.get() + i * M * ksub;
        scale += quantize_lut::aq_estimate_norm_scale(M, ksub, 2, lut);
    }
    // average over queries, then round and clamp to at least 1
    scale /= n;
    norm_scale = (int)std::roundf(std::max(scale, 1.0f));

    if (verbose) {
        printf("estimated norm scale: %lf\n", scale);
        printf("rounded norm scale: %d\n", norm_scale);
    }
}
257
+
258
+ /*********************************************************
259
+ * Code management functions
260
+ *********************************************************/
261
+
262
/// Encode vectors with the additive quantizer (flat codes, NOT the packed
/// fast-scan layout).
/// @param n               number of vectors
/// @param x               vectors to encode, size n * d
/// @param list_nos        inverted-list assignment per vector (< 0 = none)
/// @param codes           output buffer, n * code_size bytes (or
///                        n * (coarse_code_size() + code_size) with listnos)
/// @param include_listnos whether to prepend the encoded list number
void IndexIVFAdditiveQuantizerFastScan::encode_vectors(
        idx_t n,
        const float* x,
        const idx_t* list_nos,
        uint8_t* codes,
        bool include_listnos) const {
    // process in batches to bound the size of the temporary buffers.
    // NOTE(review): the recursive call offsets `codes` by i0 * code_size,
    // which does not account for the coarse code when include_listnos is
    // true — confirm against upstream whether batches > 65536 vectors with
    // listnos are expected here.
    idx_t bs = 65536;
    if (n > bs) {
        for (idx_t i0 = 0; i0 < n; i0 += bs) {
            idx_t i1 = std::min(n, i0 + bs);
            encode_vectors(
                    i1 - i0,
                    x + i0 * d,
                    list_nos + i0,
                    codes + i0 * code_size,
                    include_listnos);
        }
        return;
    }

    if (by_residual) {
        std::vector<float> residuals(n * d);
        std::vector<float> centroids(n * d);

#pragma omp parallel for if (n > 1000)
        for (idx_t i = 0; i < n; i++) {
            if (list_nos[i] < 0) {
                // unassigned vector: encode a zero residual
                memset(residuals.data() + i * d, 0, sizeof(residuals[0]) * d);
            } else {
                quantizer->compute_residual(
                        x + i * d, residuals.data() + i * d, list_nos[i]);
            }
        }

#pragma omp parallel for if (n > 1000)
        for (idx_t i = 0; i < n; i++) {
            auto c = centroids.data() + i * d;
            quantizer->reconstruct(list_nos[i], c);
        }

        // centroids are added back so the encoded norm is that of the full
        // reconstructed vector, not of the residual
        aq->compute_codes_add_centroids(
                residuals.data(), codes, n, centroids.data());

    } else {
        aq->compute_codes(x, codes, n);
    }

    if (include_listnos) {
        // expand the codes in place; iterate backwards so un-expanded
        // entries are not overwritten before they are moved
        size_t coarse_size = coarse_code_size();
        for (idx_t i = n - 1; i >= 0; i--) {
            uint8_t* code = codes + i * (coarse_size + code_size);
            memmove(code + coarse_size, codes + i * code_size, code_size);
            encode_listno(list_nos[i], code);
        }
    }
}
318
+
319
+ /*********************************************************
320
+ * Search functions
321
+ *********************************************************/
322
+
323
/// Search the index. When the norm look-up tables were scaled down at
/// LUT-quantization time (L2 with norm_scale > 1 and rescale_norm set),
/// dispatch through a NormTableScaler that re-applies the factor during
/// accumulation; otherwise defer to the plain IndexIVFFastScan search.
void IndexIVFAdditiveQuantizerFastScan::search(
        idx_t n,
        const float* x,
        idx_t k,
        float* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");

    FAISS_THROW_IF_NOT(k > 0);
    bool rescale = (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2);
    if (!rescale) {
        IndexIVFFastScan::search(n, x, k, distances, labels);
        return;
    }

    NormTableScaler scaler(norm_scale);
    // the template bool selects the heap direction; note that `rescale`
    // already implies METRIC_L2 here, so the else branch is unreachable
    if (metric_type == METRIC_L2) {
        search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
    } else {
        search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
    }
}
347
+
348
+ /*********************************************************
349
+ * Look-Up Table functions
350
+ *********************************************************/
351
+
352
+ /********************************************************
353
+
354
+ Let q denote the query vector,
355
+ x denote the quantized database vector,
356
+ c denote the corresponding IVF centroid,
357
+ r denote the residual (x - c).
358
+
359
+ The L2 distance between q and x is:
360
+
361
+ d(q, x) = (q - x)^2
362
+ = (q - c - r)^2
363
+ = q^2 - 2<q, c> - 2<q, r> + x^2
364
+
365
+ where q^2 is a constant for all x, <q,c> is only relevant to c,
366
+ and x^2 is the quantized database vector norm.
367
+
368
+ Different from IVFAdditiveQuantizer, we encode the quantized vector norm x^2
369
+ instead of r^2. So that we only need to compute one LUT for each query vector:
370
+
371
+ LUT[m][k] = -2 * <q, codebooks[m][k]>
372
+
373
+ `-2<q,c>` could be precomputed in `compute_LUT` and store in `biases`.
374
+ if `by_residual=False`, `<q,c>` is simply 0.
375
+
376
+
377
+
378
+ About norm look-up tables:
379
+
380
+ To take advantage of the fast SIMD table lookups, we encode the norm by a 2x4
381
+ bits 1D additive quantizer (simply treat the scalar norm as a 1D vector).
382
+
383
+ Let `cm` denote the codebooks of the trained 2x4 bits 1D additive quantizer,
384
+ size (2, 16); `bm` denote the encoding code of the norm, an 8-bit integer; `cb`
385
+ denote the codebooks of the additive quantizer to encode the database vector,
386
+ size (M, 16).
387
+
388
+ The decoded norm is:
389
+
390
+ decoded_norm = cm[0][bm & 15] + cm[1][bm >> 4]
391
+
392
+ The decoding is actually doing a table look-up.
393
+
394
+ We combine the norm LUTs and the IP LUTs together:
395
+
396
+ LUT is a 2D table, size (M + 2, 16)
397
+ if m < M :
398
+ LUT[m][k] = -2 * <q, cb[m][k]>
399
+ else:
400
+ LUT[m][k] = cm[m - M][k]
401
+
402
+ ********************************************************/
403
+
404
/// The LUT is 2D (one (M, ksub) table per query, shared across probed
/// lists); the per-list contribution is carried by `biases` instead.
bool IndexIVFAdditiveQuantizerFastScan::lookup_table_is_3d() const {
    return false;
}
407
+
408
/// Compute the per-query 2D look-up tables and per-(query, probe) biases.
/// For L2 the table has M = aq->M + 2 rows: aq->M rows of -2 * <q, cb[m][k]>
/// followed by 2 rows copied from the norm quantizer tables (see the block
/// comment above); for IP it is the plain inner-product LUT.
/// @param n           number of queries
/// @param x           query vectors, size n * d
/// @param coarse_ids  probed list ids, size n * nprobe
/// @param dis_tables  output LUTs, resized to n * M * ksub
/// @param biases      output biases coef * <q, c>, resized to n * nprobe
///                    (only filled when by_residual is set)
void IndexIVFAdditiveQuantizerFastScan::compute_LUT(
        size_t n,
        const float* x,
        const idx_t* coarse_ids,
        const float*,
        AlignedTable<float>& dis_tables,
        AlignedTable<float>& biases) const {
    const size_t dim12 = ksub * M;
    const size_t ip_dim12 = aq->M * ksub; // size of the IP part of one LUT

    dis_tables.resize(n * dim12);

    float coef = 1.0f;
    if (metric_type == METRIC_L2) {
        coef = -2.0f;
    }

    if (by_residual) {
        // bias = coef * <q, c>
        // NOTE: q^2 is not added to `biases`
        biases.resize(n * nprobe);
#pragma omp parallel
        {
            std::vector<float> centroid(d);
            float* c = centroid.data();

#pragma omp for
            for (idx_t ij = 0; ij < n * nprobe; ij++) {
                // NOTE(review): ij / nprobe is narrowed to int — fine while
                // n fits in int, worth confirming for very large batches
                int i = ij / nprobe;
                quantizer->reconstruct(coarse_ids[ij], c);
                biases[ij] = coef * fvec_inner_product(c, x + i * d, d);
            }
        }
    }

    if (metric_type == METRIC_L2) {
        const size_t norm_dim12 = 2 * ksub;

        // inner product look-up tables
        aq->compute_LUT(n, x, dis_tables.data(), -2.0f, dim12);

        // copy and rescale norm look-up tables
        auto norm_tabs = aq->norm_tabs;
        if (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2) {
            // the scaler re-applies norm_scale at accumulation time
            for (size_t i = 0; i < norm_tabs.size(); i++) {
                norm_tabs[i] /= norm_scale;
            }
        }
        const float* norm_lut = norm_tabs.data();
        FAISS_THROW_IF_NOT(norm_tabs.size() == norm_dim12);

        // combine them: append the 2 norm rows after the aq->M IP rows
#pragma omp parallel for if (n > 100)
        for (idx_t i = 0; i < n; i++) {
            float* tab = dis_tables.data() + i * dim12 + ip_dim12;
            memcpy(tab, norm_lut, norm_dim12 * sizeof(*tab));
        }

    } else if (metric_type == METRIC_INNER_PRODUCT) {
        aq->compute_LUT(n, x, dis_tables.get());
    } else {
        FAISS_THROW_FMT("metric %d not supported", metric_type);
    }
}
472
+
473
/// Decode n standalone codes back to float vectors; delegated entirely to
/// the additive quantizer.
void IndexIVFAdditiveQuantizerFastScan::sa_decode(
        idx_t n,
        const uint8_t* bytes,
        float* x) const {
    aq->decode(bytes, x, n);
}
479
+
480
+ /********** IndexIVFLocalSearchQuantizerFastScan ************/
481
/// IVF fast-scan index whose codes come from a 4-bit LocalSearchQuantizer.
IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // lsq is not constructed yet; init() below
                  d,
                  nlist,
                  metric,
                  bbs),
          lsq(d, M, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    init(&lsq, nlist, metric, bbs);
}
501
+
502
/// Default constructor: point the base-class aq at the embedded quantizer.
/// NOTE(review): presumably used by deserialization, which fills in lsq.
IndexIVFLocalSearchQuantizerFastScan::IndexIVFLocalSearchQuantizerFastScan() {
    aq = &lsq;
}
505
+
506
+ /********** IndexIVFResidualQuantizerFastScan ************/
507
/// IVF fast-scan index whose codes come from a 4-bit ResidualQuantizer.
IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan(
        Index* quantizer,
        size_t d,
        size_t nlist,
        size_t M,
        size_t nbits,
        MetricType metric,
        Search_type_t search_type,
        int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // rq is not constructed yet; init() below
                  d,
                  nlist,
                  metric,
                  bbs),
          rq(d, M, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    init(&rq, nlist, metric, bbs);
}
527
+
528
/// Default constructor: point the base-class aq at the embedded quantizer.
/// NOTE(review): presumably used by deserialization, which fills in rq.
IndexIVFResidualQuantizerFastScan::IndexIVFResidualQuantizerFastScan() {
    aq = &rq;
}
531
+
532
+ /********** IndexIVFProductLocalSearchQuantizerFastScan ************/
533
/// IVF fast-scan index whose codes come from a 4-bit
/// ProductLocalSearchQuantizer (nsplits sub-spaces of Msub codebooks each).
IndexIVFProductLocalSearchQuantizerFastScan::
        IndexIVFProductLocalSearchQuantizerFastScan(
                Index* quantizer,
                size_t d,
                size_t nlist,
                size_t nsplits,
                size_t Msub,
                size_t nbits,
                MetricType metric,
                Search_type_t search_type,
                int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // plsq is not constructed yet; init() below
                  d,
                  nlist,
                  metric,
                  bbs),
          plsq(d, nsplits, Msub, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    init(&plsq, nlist, metric, bbs);
}
555
+
556
/// Default constructor: point the base-class aq at the embedded quantizer.
/// NOTE(review): presumably used by deserialization, which fills in plsq.
IndexIVFProductLocalSearchQuantizerFastScan::
        IndexIVFProductLocalSearchQuantizerFastScan() {
    aq = &plsq;
}
560
+
561
+ /********** IndexIVFProductResidualQuantizerFastScan ************/
562
/// IVF fast-scan index whose codes come from a 4-bit
/// ProductResidualQuantizer (nsplits sub-spaces of Msub codebooks each).
IndexIVFProductResidualQuantizerFastScan::
        IndexIVFProductResidualQuantizerFastScan(
                Index* quantizer,
                size_t d,
                size_t nlist,
                size_t nsplits,
                size_t Msub,
                size_t nbits,
                MetricType metric,
                Search_type_t search_type,
                int bbs)
        : IndexIVFAdditiveQuantizerFastScan(
                  quantizer,
                  nullptr, // prq is not constructed yet; init() below
                  d,
                  nlist,
                  metric,
                  bbs),
          prq(d, nsplits, Msub, nbits, search_type) {
    FAISS_THROW_IF_NOT(nbits == 4);
    init(&prq, nlist, metric, bbs);
}
584
+
585
/// Default constructor: point the base-class aq at the embedded quantizer.
/// NOTE(review): presumably used by deserialization, which fills in prq.
IndexIVFProductResidualQuantizerFastScan::
        IndexIVFProductResidualQuantizerFastScan() {
    aq = &prq;
}
589
+
590
+ } // namespace faiss