faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -0,0 +1,626 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexFastScan.h>
9
+
10
+ #include <limits.h>
11
+ #include <cassert>
12
+ #include <memory>
13
+
14
+ #include <omp.h>
15
+
16
+ #include <faiss/impl/FaissAssert.h>
17
+ #include <faiss/impl/IDSelector.h>
18
+ #include <faiss/impl/LookupTableScaler.h>
19
+ #include <faiss/impl/ResultHandler.h>
20
+ #include <faiss/utils/distances.h>
21
+ #include <faiss/utils/extra_distances.h>
22
+ #include <faiss/utils/hamming.h>
23
+ #include <faiss/utils/random.h>
24
+ #include <faiss/utils/utils.h>
25
+
26
+ #include <faiss/impl/pq4_fast_scan.h>
27
+ #include <faiss/impl/simd_result_handlers.h>
28
+ #include <faiss/utils/quantize_lut.h>
29
+
30
+ namespace faiss {
31
+
32
+ using namespace simd_result_handlers;
33
+
34
+ inline size_t roundup(size_t a, size_t b) {
35
+ return (a + b - 1) / b * b;
36
+ }
37
+
38
+ void IndexFastScan::init_fastscan(
39
+ int d,
40
+ size_t M,
41
+ size_t nbits,
42
+ MetricType metric,
43
+ int bbs) {
44
+ FAISS_THROW_IF_NOT(nbits == 4);
45
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
46
+ this->d = d;
47
+ this->M = M;
48
+ this->nbits = nbits;
49
+ this->metric_type = metric;
50
+ this->bbs = bbs;
51
+ ksub = (1 << nbits);
52
+
53
+ code_size = (M * nbits + 7) / 8;
54
+ ntotal = ntotal2 = 0;
55
+ M2 = roundup(M, 2);
56
+ is_trained = false;
57
+ }
58
+
59
+ IndexFastScan::IndexFastScan()
60
+ : bbs(0), M(0), code_size(0), ntotal2(0), M2(0) {}
61
+
62
+ void IndexFastScan::reset() {
63
+ codes.resize(0);
64
+ ntotal = 0;
65
+ }
66
+
67
+ void IndexFastScan::add(idx_t n, const float* x) {
68
+ FAISS_THROW_IF_NOT(is_trained);
69
+
70
+ // do some blocking to avoid excessive allocs
71
+ constexpr idx_t bs = 65536;
72
+ if (n > bs) {
73
+ for (idx_t i0 = 0; i0 < n; i0 += bs) {
74
+ idx_t i1 = std::min(n, i0 + bs);
75
+ if (verbose) {
76
+ printf("IndexFastScan::add %zd/%zd\n", size_t(i1), size_t(n));
77
+ }
78
+ add(i1 - i0, x + i0 * d);
79
+ }
80
+ return;
81
+ }
82
+ InterruptCallback::check();
83
+
84
+ AlignedTable<uint8_t> tmp_codes(n * code_size);
85
+ compute_codes(tmp_codes.get(), n, x);
86
+
87
+ ntotal2 = roundup(ntotal + n, bbs);
88
+ size_t new_size = ntotal2 * M2 / 2; // assume nbits = 4
89
+ size_t old_size = codes.size();
90
+ if (new_size > old_size) {
91
+ codes.resize(new_size);
92
+ memset(codes.get() + old_size, 0, new_size - old_size);
93
+ }
94
+
95
+ pq4_pack_codes_range(
96
+ tmp_codes.get(), M, ntotal, ntotal + n, bbs, M2, codes.get());
97
+
98
+ ntotal += n;
99
+ }
100
+
101
+ size_t IndexFastScan::remove_ids(const IDSelector& sel) {
102
+ idx_t j = 0;
103
+ for (idx_t i = 0; i < ntotal; i++) {
104
+ if (sel.is_member(i)) {
105
+ // should be removed
106
+ } else {
107
+ if (i > j) {
108
+ for (int sq = 0; sq < M; sq++) {
109
+ uint8_t code =
110
+ pq4_get_packed_element(codes.data(), bbs, M, i, sq);
111
+ pq4_set_packed_element(codes.data(), code, bbs, M, j, sq);
112
+ }
113
+ }
114
+ j++;
115
+ }
116
+ }
117
+ size_t nremove = ntotal - j;
118
+ if (nremove > 0) {
119
+ ntotal = j;
120
+ ntotal2 = roundup(ntotal, bbs);
121
+ size_t new_size = ntotal2 * M2 / 2;
122
+ codes.resize(new_size);
123
+ }
124
+ return nremove;
125
+ }
126
+
127
+ void IndexFastScan::check_compatible_for_merge(const Index& otherIndex) const {
128
+ const IndexFastScan* other =
129
+ dynamic_cast<const IndexFastScan*>(&otherIndex);
130
+ FAISS_THROW_IF_NOT(other);
131
+ FAISS_THROW_IF_NOT(other->M == M);
132
+ FAISS_THROW_IF_NOT(other->bbs == bbs);
133
+ FAISS_THROW_IF_NOT(other->d == d);
134
+ FAISS_THROW_IF_NOT(other->code_size == code_size);
135
+ FAISS_THROW_IF_NOT_MSG(
136
+ typeid(*this) == typeid(*other),
137
+ "can only merge indexes of the same type");
138
+ }
139
+
140
+ void IndexFastScan::merge_from(Index& otherIndex, idx_t add_id) {
141
+ check_compatible_for_merge(otherIndex);
142
+ IndexFastScan* other = static_cast<IndexFastScan*>(&otherIndex);
143
+ ntotal2 = roundup(ntotal + other->ntotal, bbs);
144
+ codes.resize(ntotal2 * M2 / 2);
145
+ for (int i = 0; i < other->ntotal; i++) {
146
+ for (int sq = 0; sq < M; sq++) {
147
+ uint8_t code =
148
+ pq4_get_packed_element(other->codes.data(), bbs, M, i, sq);
149
+ pq4_set_packed_element(codes.data(), code, bbs, M, ntotal + i, sq);
150
+ }
151
+ }
152
+ ntotal += other->ntotal;
153
+ other->reset();
154
+ }
155
+
156
+ namespace {
157
+
158
+ template <class C, typename dis_t, class Scaler>
159
+ void estimators_from_tables_generic(
160
+ const IndexFastScan& index,
161
+ const uint8_t* codes,
162
+ size_t ncodes,
163
+ const dis_t* dis_table,
164
+ size_t k,
165
+ typename C::T* heap_dis,
166
+ int64_t* heap_ids,
167
+ const Scaler& scaler) {
168
+ using accu_t = typename C::T;
169
+
170
+ for (size_t j = 0; j < ncodes; ++j) {
171
+ BitstringReader bsr(codes + j * index.code_size, index.code_size);
172
+ accu_t dis = 0;
173
+ const dis_t* dt = dis_table;
174
+ for (size_t m = 0; m < index.M - scaler.nscale; m++) {
175
+ uint64_t c = bsr.read(index.nbits);
176
+ dis += dt[c];
177
+ dt += index.ksub;
178
+ }
179
+
180
+ for (size_t m = 0; m < scaler.nscale; m++) {
181
+ uint64_t c = bsr.read(index.nbits);
182
+ dis += scaler.scale_one(dt[c]);
183
+ dt += index.ksub;
184
+ }
185
+
186
+ if (C::cmp(heap_dis[0], dis)) {
187
+ heap_pop<C>(k, heap_dis, heap_ids);
188
+ heap_push<C>(k, heap_dis, heap_ids, dis, j);
189
+ }
190
+ }
191
+ }
192
+
193
+ } // anonymous namespace
194
+
195
+ using namespace quantize_lut;
196
+
197
+ void IndexFastScan::compute_quantized_LUT(
198
+ idx_t n,
199
+ const float* x,
200
+ uint8_t* lut,
201
+ float* normalizers) const {
202
+ size_t dim12 = ksub * M;
203
+ std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
204
+ compute_float_LUT(dis_tables.get(), n, x);
205
+
206
+ for (uint64_t i = 0; i < n; i++) {
207
+ round_uint8_per_column(
208
+ dis_tables.get() + i * dim12,
209
+ M,
210
+ ksub,
211
+ &normalizers[2 * i],
212
+ &normalizers[2 * i + 1]);
213
+ }
214
+
215
+ for (uint64_t i = 0; i < n; i++) {
216
+ const float* t_in = dis_tables.get() + i * dim12;
217
+ uint8_t* t_out = lut + i * M2 * ksub;
218
+
219
+ for (int j = 0; j < dim12; j++) {
220
+ t_out[j] = int(t_in[j]);
221
+ }
222
+ memset(t_out + dim12, 0, (M2 - M) * ksub);
223
+ }
224
+ }
225
+
226
+ /******************************************************************************
227
+ * Search driver routine
228
+ ******************************************************************************/
229
+
230
+ void IndexFastScan::search(
231
+ idx_t n,
232
+ const float* x,
233
+ idx_t k,
234
+ float* distances,
235
+ idx_t* labels,
236
+ const SearchParameters* params) const {
237
+ FAISS_THROW_IF_NOT_MSG(
238
+ !params, "search params not supported for this index");
239
+ FAISS_THROW_IF_NOT(k > 0);
240
+
241
+ DummyScaler scaler;
242
+ if (metric_type == METRIC_L2) {
243
+ search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
244
+ } else {
245
+ search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
246
+ }
247
+ }
248
+
249
+ template <bool is_max, class Scaler>
250
+ void IndexFastScan::search_dispatch_implem(
251
+ idx_t n,
252
+ const float* x,
253
+ idx_t k,
254
+ float* distances,
255
+ idx_t* labels,
256
+ const Scaler& scaler) const {
257
+ using Cfloat = typename std::conditional<
258
+ is_max,
259
+ CMax<float, int64_t>,
260
+ CMin<float, int64_t>>::type;
261
+
262
+ using C = typename std::
263
+ conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
264
+
265
+ if (n == 0) {
266
+ return;
267
+ }
268
+
269
+ // actual implementation used
270
+ int impl = implem;
271
+
272
+ if (impl == 0) {
273
+ if (bbs == 32) {
274
+ impl = 12;
275
+ } else {
276
+ impl = 14;
277
+ }
278
+ if (k > 20) {
279
+ impl++;
280
+ }
281
+ }
282
+
283
+ if (implem == 1) {
284
+ FAISS_THROW_MSG("not implemented");
285
+ } else if (implem == 2 || implem == 3 || implem == 4) {
286
+ FAISS_THROW_IF_NOT(orig_codes != nullptr);
287
+ search_implem_234<Cfloat>(n, x, k, distances, labels, scaler);
288
+ } else if (impl >= 12 && impl <= 15) {
289
+ FAISS_THROW_IF_NOT(ntotal < INT_MAX);
290
+ int nt = std::min(omp_get_max_threads(), int(n));
291
+ if (nt < 2) {
292
+ if (impl == 12 || impl == 13) {
293
+ search_implem_12<C>(n, x, k, distances, labels, impl, scaler);
294
+ } else {
295
+ search_implem_14<C>(n, x, k, distances, labels, impl, scaler);
296
+ }
297
+ } else {
298
+ // explicitly slice over threads
299
+ #pragma omp parallel for num_threads(nt)
300
+ for (int slice = 0; slice < nt; slice++) {
301
+ idx_t i0 = n * slice / nt;
302
+ idx_t i1 = n * (slice + 1) / nt;
303
+ float* dis_i = distances + i0 * k;
304
+ idx_t* lab_i = labels + i0 * k;
305
+ if (impl == 12 || impl == 13) {
306
+ search_implem_12<C>(
307
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
308
+ } else {
309
+ search_implem_14<C>(
310
+ i1 - i0, x + i0 * d, k, dis_i, lab_i, impl, scaler);
311
+ }
312
+ }
313
+ }
314
+ } else {
315
+ FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
316
+ }
317
+ }
318
+
319
+ template <class Cfloat, class Scaler>
320
+ void IndexFastScan::search_implem_234(
321
+ idx_t n,
322
+ const float* x,
323
+ idx_t k,
324
+ float* distances,
325
+ idx_t* labels,
326
+ const Scaler& scaler) const {
327
+ FAISS_THROW_IF_NOT(implem == 2 || implem == 3 || implem == 4);
328
+
329
+ const size_t dim12 = ksub * M;
330
+ std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
331
+ compute_float_LUT(dis_tables.get(), n, x);
332
+
333
+ std::vector<float> normalizers(n * 2);
334
+
335
+ if (implem == 2) {
336
+ // default float
337
+ } else if (implem == 3 || implem == 4) {
338
+ for (uint64_t i = 0; i < n; i++) {
339
+ round_uint8_per_column(
340
+ dis_tables.get() + i * dim12,
341
+ M,
342
+ ksub,
343
+ &normalizers[2 * i],
344
+ &normalizers[2 * i + 1]);
345
+ }
346
+ }
347
+
348
+ #pragma omp parallel for if (n > 1000)
349
+ for (int64_t i = 0; i < n; i++) {
350
+ int64_t* heap_ids = labels + i * k;
351
+ float* heap_dis = distances + i * k;
352
+
353
+ heap_heapify<Cfloat>(k, heap_dis, heap_ids);
354
+
355
+ estimators_from_tables_generic<Cfloat>(
356
+ *this,
357
+ orig_codes,
358
+ ntotal,
359
+ dis_tables.get() + i * dim12,
360
+ k,
361
+ heap_dis,
362
+ heap_ids,
363
+ scaler);
364
+
365
+ heap_reorder<Cfloat>(k, heap_dis, heap_ids);
366
+
367
+ if (implem == 4) {
368
+ float a = normalizers[2 * i];
369
+ float b = normalizers[2 * i + 1];
370
+
371
+ for (int j = 0; j < k; j++) {
372
+ heap_dis[j] = heap_dis[j] / a + b;
373
+ }
374
+ }
375
+ }
376
+ }
377
+
378
+ template <class C, class Scaler>
379
+ void IndexFastScan::search_implem_12(
380
+ idx_t n,
381
+ const float* x,
382
+ idx_t k,
383
+ float* distances,
384
+ idx_t* labels,
385
+ int impl,
386
+ const Scaler& scaler) const {
387
+ FAISS_THROW_IF_NOT(bbs == 32);
388
+
389
+ // handle qbs2 blocking by recursive call
390
+ int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
391
+ if (n > qbs2) {
392
+ for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
393
+ int64_t i1 = std::min(i0 + qbs2, n);
394
+ search_implem_12<C>(
395
+ i1 - i0,
396
+ x + d * i0,
397
+ k,
398
+ distances + i0 * k,
399
+ labels + i0 * k,
400
+ impl,
401
+ scaler);
402
+ }
403
+ return;
404
+ }
405
+
406
+ size_t dim12 = ksub * M2;
407
+ AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
408
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
409
+
410
+ if (skip & 1) {
411
+ quantized_dis_tables.clear();
412
+ } else {
413
+ compute_quantized_LUT(
414
+ n, x, quantized_dis_tables.get(), normalizers.get());
415
+ }
416
+
417
+ AlignedTable<uint8_t> LUT(n * dim12);
418
+
419
+ // block sizes are encoded in qbs, 4 bits at a time
420
+
421
+ // caution: we override an object field
422
+ int qbs = this->qbs;
423
+
424
+ if (n != pq4_qbs_to_nq(qbs)) {
425
+ qbs = pq4_preferred_qbs(n);
426
+ }
427
+
428
+ int LUT_nq =
429
+ pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
430
+ FAISS_THROW_IF_NOT(LUT_nq == n);
431
+
432
+ if (k == 1) {
433
+ SingleResultHandler<C> handler(n, ntotal);
434
+ if (skip & 4) {
435
+ // pass
436
+ } else {
437
+ handler.disable = bool(skip & 2);
438
+ pq4_accumulate_loop_qbs(
439
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
440
+ }
441
+
442
+ handler.to_flat_arrays(distances, labels, normalizers.get());
443
+
444
+ } else if (impl == 12) {
445
+ std::vector<uint16_t> tmp_dis(n * k);
446
+ std::vector<int32_t> tmp_ids(n * k);
447
+
448
+ if (skip & 4) {
449
+ // skip
450
+ } else {
451
+ HeapHandler<C> handler(
452
+ n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
453
+ handler.disable = bool(skip & 2);
454
+
455
+ pq4_accumulate_loop_qbs(
456
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
457
+
458
+ if (!(skip & 8)) {
459
+ handler.to_flat_arrays(distances, labels, normalizers.get());
460
+ }
461
+ }
462
+
463
+ } else { // impl == 13
464
+
465
+ ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
466
+ handler.disable = bool(skip & 2);
467
+
468
+ if (skip & 4) {
469
+ // skip
470
+ } else {
471
+ pq4_accumulate_loop_qbs(
472
+ qbs, ntotal2, M2, codes.get(), LUT.get(), handler, scaler);
473
+ }
474
+
475
+ if (!(skip & 8)) {
476
+ handler.to_flat_arrays(distances, labels, normalizers.get());
477
+ }
478
+
479
+ FastScan_stats.t0 += handler.times[0];
480
+ FastScan_stats.t1 += handler.times[1];
481
+ FastScan_stats.t2 += handler.times[2];
482
+ FastScan_stats.t3 += handler.times[3];
483
+ }
484
+ }
485
+
486
+ FastScanStats FastScan_stats;
487
+
488
+ template <class C, class Scaler>
489
+ void IndexFastScan::search_implem_14(
490
+ idx_t n,
491
+ const float* x,
492
+ idx_t k,
493
+ float* distances,
494
+ idx_t* labels,
495
+ int impl,
496
+ const Scaler& scaler) const {
497
+ FAISS_THROW_IF_NOT(bbs % 32 == 0);
498
+
499
+ int qbs2 = qbs == 0 ? 4 : qbs;
500
+
501
+ // handle qbs2 blocking by recursive call
502
+ if (n > qbs2) {
503
+ for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
504
+ int64_t i1 = std::min(i0 + qbs2, n);
505
+ search_implem_14<C>(
506
+ i1 - i0,
507
+ x + d * i0,
508
+ k,
509
+ distances + i0 * k,
510
+ labels + i0 * k,
511
+ impl,
512
+ scaler);
513
+ }
514
+ return;
515
+ }
516
+
517
+ size_t dim12 = ksub * M2;
518
+ AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
519
+ std::unique_ptr<float[]> normalizers(new float[2 * n]);
520
+
521
+ if (skip & 1) {
522
+ quantized_dis_tables.clear();
523
+ } else {
524
+ compute_quantized_LUT(
525
+ n, x, quantized_dis_tables.get(), normalizers.get());
526
+ }
527
+
528
+ AlignedTable<uint8_t> LUT(n * dim12);
529
+ pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
530
+
531
+ if (k == 1) {
532
+ SingleResultHandler<C> handler(n, ntotal);
533
+ if (skip & 4) {
534
+ // pass
535
+ } else {
536
+ handler.disable = bool(skip & 2);
537
+ pq4_accumulate_loop(
538
+ n,
539
+ ntotal2,
540
+ bbs,
541
+ M2,
542
+ codes.get(),
543
+ LUT.get(),
544
+ handler,
545
+ scaler);
546
+ }
547
+ handler.to_flat_arrays(distances, labels, normalizers.get());
548
+
549
+ } else if (impl == 14) {
550
+ std::vector<uint16_t> tmp_dis(n * k);
551
+ std::vector<int32_t> tmp_ids(n * k);
552
+
553
+ if (skip & 4) {
554
+ // skip
555
+ } else if (k > 1) {
556
+ HeapHandler<C> handler(
557
+ n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
558
+ handler.disable = bool(skip & 2);
559
+
560
+ pq4_accumulate_loop(
561
+ n,
562
+ ntotal2,
563
+ bbs,
564
+ M2,
565
+ codes.get(),
566
+ LUT.get(),
567
+ handler,
568
+ scaler);
569
+
570
+ if (!(skip & 8)) {
571
+ handler.to_flat_arrays(distances, labels, normalizers.get());
572
+ }
573
+ }
574
+
575
+ } else { // impl == 15
576
+
577
+ ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
578
+ handler.disable = bool(skip & 2);
579
+
580
+ if (skip & 4) {
581
+ // skip
582
+ } else {
583
+ pq4_accumulate_loop(
584
+ n,
585
+ ntotal2,
586
+ bbs,
587
+ M2,
588
+ codes.get(),
589
+ LUT.get(),
590
+ handler,
591
+ scaler);
592
+ }
593
+
594
+ if (!(skip & 8)) {
595
+ handler.to_flat_arrays(distances, labels, normalizers.get());
596
+ }
597
+ }
598
+ }
599
+
600
+ template void IndexFastScan::search_dispatch_implem<true, NormTableScaler>(
601
+ idx_t n,
602
+ const float* x,
603
+ idx_t k,
604
+ float* distances,
605
+ idx_t* labels,
606
+ const NormTableScaler& scaler) const;
607
+
608
+ template void IndexFastScan::search_dispatch_implem<false, NormTableScaler>(
609
+ idx_t n,
610
+ const float* x,
611
+ idx_t k,
612
+ float* distances,
613
+ idx_t* labels,
614
+ const NormTableScaler& scaler) const;
615
+
616
+ void IndexFastScan::reconstruct(idx_t key, float* recons) const {
617
+ std::vector<uint8_t> code(code_size, 0);
618
+ BitstringWriter bsw(code.data(), code_size);
619
+ for (size_t m = 0; m < M; m++) {
620
+ uint8_t c = pq4_get_packed_element(codes.data(), bbs, M2, key, m);
621
+ bsw.write(c, nbits);
622
+ }
623
+ sa_decode(1, code.data(), recons);
624
+ }
625
+
626
+ } // namespace faiss