faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -14,17 +14,11 @@
14
14
  #include <omp.h>
15
15
 
16
16
  #include <faiss/impl/FaissAssert.h>
17
- #include <faiss/utils/random.h>
18
- #include <faiss/utils/utils.h>
19
-
20
17
  #include <faiss/impl/pq4_fast_scan.h>
21
- #include <faiss/impl/simd_result_handlers.h>
22
- #include <faiss/utils/quantize_lut.h>
18
+ #include <faiss/utils/utils.h>
23
19
 
24
20
  namespace faiss {
25
21
 
26
- using namespace simd_result_handlers;
27
-
28
22
  inline size_t roundup(size_t a, size_t b) {
29
23
  return (a + b - 1) / b * b;
30
24
  }
@@ -35,37 +29,19 @@ IndexPQFastScan::IndexPQFastScan(
35
29
  size_t nbits,
36
30
  MetricType metric,
37
31
  int bbs)
38
- : Index(d, metric),
39
- pq(d, M, nbits),
40
- bbs(bbs),
41
- ntotal2(0),
42
- M2(roundup(M, 2)) {
43
- FAISS_THROW_IF_NOT(nbits == 4);
44
- is_trained = false;
32
+ : pq(d, M, nbits) {
33
+ init_fastscan(d, M, nbits, metric, bbs);
45
34
  }
46
35
 
47
- IndexPQFastScan::IndexPQFastScan() : bbs(0), ntotal2(0), M2(0) {}
48
-
49
- IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs)
50
- : Index(orig.d, orig.metric_type), pq(orig.pq), bbs(bbs) {
51
- FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
36
+ IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs) : pq(orig.pq) {
37
+ init_fastscan(orig.d, pq.M, pq.nbits, orig.metric_type, bbs);
52
38
  ntotal = orig.ntotal;
39
+ ntotal2 = roundup(ntotal, bbs);
53
40
  is_trained = orig.is_trained;
54
41
  orig_codes = orig.codes.data();
55
42
 
56
- qbs = 0; // means use default
57
-
58
43
  // pack the codes
59
-
60
- size_t M = pq.M;
61
-
62
- FAISS_THROW_IF_NOT(bbs % 32 == 0);
63
- M2 = roundup(M, 2);
64
- ntotal2 = roundup(ntotal, bbs);
65
-
66
44
  codes.resize(ntotal2 * M2 / 2);
67
-
68
- // printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
69
45
  pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get());
70
46
  }
71
47
 
@@ -77,433 +53,22 @@ void IndexPQFastScan::train(idx_t n, const float* x) {
77
53
  is_trained = true;
78
54
  }
79
55
 
80
- void IndexPQFastScan::add(idx_t n, const float* x) {
81
- FAISS_THROW_IF_NOT(is_trained);
82
- AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
83
- pq.compute_codes(x, tmp_codes.get(), n);
84
- ntotal2 = roundup(ntotal + n, bbs);
85
- size_t new_size = ntotal2 * M2 / 2;
86
- size_t old_size = codes.size();
87
- if (new_size > old_size) {
88
- codes.resize(new_size);
89
- memset(codes.get() + old_size, 0, new_size - old_size);
90
- }
91
- pq4_pack_codes_range(
92
- tmp_codes.get(), pq.M, ntotal, ntotal + n, bbs, M2, codes.get());
93
- ntotal += n;
94
- }
95
-
96
- void IndexPQFastScan::reset() {
97
- codes.resize(0);
98
- ntotal = 0;
99
- }
100
-
101
- namespace {
102
-
103
- // from impl/ProductQuantizer.cpp
104
- template <class C, typename dis_t>
105
- void pq_estimators_from_tables_generic(
106
- const ProductQuantizer& pq,
107
- size_t nbits,
108
- const uint8_t* codes,
109
- size_t ncodes,
110
- const dis_t* dis_table,
111
- size_t k,
112
- typename C::T* heap_dis,
113
- int64_t* heap_ids) {
114
- using accu_t = typename C::T;
115
- const size_t M = pq.M;
116
- const size_t ksub = pq.ksub;
117
- for (size_t j = 0; j < ncodes; ++j) {
118
- PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
119
- accu_t dis = 0;
120
- const dis_t* __restrict dt = dis_table;
121
- for (size_t m = 0; m < M; m++) {
122
- uint64_t c = decoder.decode();
123
- dis += dt[c];
124
- dt += ksub;
125
- }
126
-
127
- if (C::cmp(heap_dis[0], dis)) {
128
- heap_pop<C>(k, heap_dis, heap_ids);
129
- heap_push<C>(k, heap_dis, heap_ids, dis, j);
130
- }
131
- }
132
- }
133
-
134
- } // anonymous namespace
135
-
136
- using namespace quantize_lut;
137
-
138
- void IndexPQFastScan::compute_quantized_LUT(
139
- idx_t n,
140
- const float* x,
141
- uint8_t* lut,
142
- float* normalizers) const {
143
- size_t dim12 = pq.ksub * pq.M;
144
- std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
145
- if (metric_type == METRIC_L2) {
146
- pq.compute_distance_tables(n, x, dis_tables.get());
147
- } else {
148
- pq.compute_inner_prod_tables(n, x, dis_tables.get());
149
- }
150
-
151
- for (uint64_t i = 0; i < n; i++) {
152
- round_uint8_per_column(
153
- dis_tables.get() + i * dim12,
154
- pq.M,
155
- pq.ksub,
156
- &normalizers[2 * i],
157
- &normalizers[2 * i + 1]);
158
- }
159
-
160
- for (uint64_t i = 0; i < n; i++) {
161
- const float* t_in = dis_tables.get() + i * dim12;
162
- uint8_t* t_out = lut + i * M2 * pq.ksub;
163
-
164
- for (int j = 0; j < dim12; j++) {
165
- t_out[j] = int(t_in[j]);
166
- }
167
- memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
168
- }
56
+ void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
57
+ const {
58
+ pq.compute_codes(x, codes, n);
169
59
  }
170
60
 
171
- /******************************************************************************
172
- * Search driver routine
173
- ******************************************************************************/
174
-
175
- void IndexPQFastScan::search(
176
- idx_t n,
177
- const float* x,
178
- idx_t k,
179
- float* distances,
180
- idx_t* labels) const {
181
- FAISS_THROW_IF_NOT(k > 0);
182
-
61
+ void IndexPQFastScan::compute_float_LUT(float* lut, idx_t n, const float* x)
62
+ const {
183
63
  if (metric_type == METRIC_L2) {
184
- search_dispatch_implem<true>(n, x, k, distances, labels);
64
+ pq.compute_distance_tables(n, x, lut);
185
65
  } else {
186
- search_dispatch_implem<false>(n, x, k, distances, labels);
66
+ pq.compute_inner_prod_tables(n, x, lut);
187
67
  }
188
68
  }
189
69
 
190
- template <bool is_max>
191
- void IndexPQFastScan::search_dispatch_implem(
192
- idx_t n,
193
- const float* x,
194
- idx_t k,
195
- float* distances,
196
- idx_t* labels) const {
197
- using Cfloat = typename std::conditional<
198
- is_max,
199
- CMax<float, int64_t>,
200
- CMin<float, int64_t>>::type;
201
-
202
- using C = typename std::
203
- conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
204
-
205
- if (n == 0) {
206
- return;
207
- }
208
-
209
- // actual implementation used
210
- int impl = implem;
211
-
212
- if (impl == 0) {
213
- if (bbs == 32) {
214
- impl = 12;
215
- } else {
216
- impl = 14;
217
- }
218
- if (k > 20) {
219
- impl++;
220
- }
221
- }
222
-
223
- if (implem == 1) {
224
- FAISS_THROW_IF_NOT(orig_codes);
225
- FAISS_THROW_IF_NOT(is_max);
226
- float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
227
- pq.search(x, n, orig_codes, ntotal, &res, true);
228
- } else if (implem == 2 || implem == 3 || implem == 4) {
229
- FAISS_THROW_IF_NOT(orig_codes);
230
-
231
- size_t dim12 = pq.ksub * pq.M;
232
- std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
233
- if (is_max) {
234
- pq.compute_distance_tables(n, x, dis_tables.get());
235
- } else {
236
- pq.compute_inner_prod_tables(n, x, dis_tables.get());
237
- }
238
-
239
- std::vector<float> normalizers(n * 2);
240
-
241
- if (implem == 2) {
242
- // default float
243
- } else if (implem == 3 || implem == 4) {
244
- for (uint64_t i = 0; i < n; i++) {
245
- round_uint8_per_column(
246
- dis_tables.get() + i * dim12,
247
- pq.M,
248
- pq.ksub,
249
- &normalizers[2 * i],
250
- &normalizers[2 * i + 1]);
251
- }
252
- }
253
-
254
- for (int64_t i = 0; i < n; i++) {
255
- int64_t* heap_ids = labels + i * k;
256
- float* heap_dis = distances + i * k;
257
-
258
- heap_heapify<Cfloat>(k, heap_dis, heap_ids);
259
-
260
- pq_estimators_from_tables_generic<Cfloat>(
261
- pq,
262
- pq.nbits,
263
- orig_codes,
264
- ntotal,
265
- dis_tables.get() + i * dim12,
266
- k,
267
- heap_dis,
268
- heap_ids);
269
-
270
- heap_reorder<Cfloat>(k, heap_dis, heap_ids);
271
-
272
- if (implem == 4) {
273
- float a = normalizers[2 * i];
274
- float b = normalizers[2 * i + 1];
275
-
276
- for (int j = 0; j < k; j++) {
277
- heap_dis[j] = heap_dis[j] / a + b;
278
- }
279
- }
280
- }
281
- } else if (impl >= 12 && impl <= 15) {
282
- FAISS_THROW_IF_NOT(ntotal < INT_MAX);
283
- int nt = std::min(omp_get_max_threads(), int(n));
284
- if (nt < 2) {
285
- if (impl == 12 || impl == 13) {
286
- search_implem_12<C>(n, x, k, distances, labels, impl);
287
- } else {
288
- search_implem_14<C>(n, x, k, distances, labels, impl);
289
- }
290
- } else {
291
- // explicitly slice over threads
292
- #pragma omp parallel for num_threads(nt)
293
- for (int slice = 0; slice < nt; slice++) {
294
- idx_t i0 = n * slice / nt;
295
- idx_t i1 = n * (slice + 1) / nt;
296
- float* dis_i = distances + i0 * k;
297
- idx_t* lab_i = labels + i0 * k;
298
- if (impl == 12 || impl == 13) {
299
- search_implem_12<C>(
300
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
301
- } else {
302
- search_implem_14<C>(
303
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
304
- }
305
- }
306
- }
307
- } else {
308
- FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
309
- }
310
- }
311
-
312
- template <class C>
313
- void IndexPQFastScan::search_implem_12(
314
- idx_t n,
315
- const float* x,
316
- idx_t k,
317
- float* distances,
318
- idx_t* labels,
319
- int impl) const {
320
- FAISS_THROW_IF_NOT(bbs == 32);
321
-
322
- // handle qbs2 blocking by recursive call
323
- int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
324
- if (n > qbs2) {
325
- for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
326
- int64_t i1 = std::min(i0 + qbs2, n);
327
- search_implem_12<C>(
328
- i1 - i0,
329
- x + d * i0,
330
- k,
331
- distances + i0 * k,
332
- labels + i0 * k,
333
- impl);
334
- }
335
- return;
336
- }
337
-
338
- size_t dim12 = pq.ksub * M2;
339
- AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
340
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
341
-
342
- if (skip & 1) {
343
- quantized_dis_tables.clear();
344
- } else {
345
- compute_quantized_LUT(
346
- n, x, quantized_dis_tables.get(), normalizers.get());
347
- }
348
-
349
- AlignedTable<uint8_t> LUT(n * dim12);
350
-
351
- // block sizes are encoded in qbs, 4 bits at a time
352
-
353
- // caution: we override an object field
354
- int qbs = this->qbs;
355
-
356
- if (n != pq4_qbs_to_nq(qbs)) {
357
- qbs = pq4_preferred_qbs(n);
358
- }
359
-
360
- int LUT_nq =
361
- pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
362
- FAISS_THROW_IF_NOT(LUT_nq == n);
363
-
364
- if (k == 1) {
365
- SingleResultHandler<C> handler(n, ntotal);
366
- if (skip & 4) {
367
- // pass
368
- } else {
369
- handler.disable = bool(skip & 2);
370
- pq4_accumulate_loop_qbs(
371
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
372
- }
373
-
374
- handler.to_flat_arrays(distances, labels, normalizers.get());
375
-
376
- } else if (impl == 12) {
377
- std::vector<uint16_t> tmp_dis(n * k);
378
- std::vector<int32_t> tmp_ids(n * k);
379
-
380
- if (skip & 4) {
381
- // skip
382
- } else {
383
- HeapHandler<C> handler(
384
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
385
- handler.disable = bool(skip & 2);
386
-
387
- pq4_accumulate_loop_qbs(
388
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
389
-
390
- if (!(skip & 8)) {
391
- handler.to_flat_arrays(distances, labels, normalizers.get());
392
- }
393
- }
394
-
395
- } else { // impl == 13
396
-
397
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
398
- handler.disable = bool(skip & 2);
399
-
400
- if (skip & 4) {
401
- // skip
402
- } else {
403
- pq4_accumulate_loop_qbs(
404
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
405
- }
406
-
407
- if (!(skip & 8)) {
408
- handler.to_flat_arrays(distances, labels, normalizers.get());
409
- }
410
-
411
- FastScan_stats.t0 += handler.times[0];
412
- FastScan_stats.t1 += handler.times[1];
413
- FastScan_stats.t2 += handler.times[2];
414
- FastScan_stats.t3 += handler.times[3];
415
- }
416
- }
417
-
418
- FastScanStats FastScan_stats;
419
-
420
- template <class C>
421
- void IndexPQFastScan::search_implem_14(
422
- idx_t n,
423
- const float* x,
424
- idx_t k,
425
- float* distances,
426
- idx_t* labels,
427
- int impl) const {
428
- FAISS_THROW_IF_NOT(bbs % 32 == 0);
429
-
430
- int qbs2 = qbs == 0 ? 4 : qbs;
431
-
432
- // handle qbs2 blocking by recursive call
433
- if (n > qbs2) {
434
- for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
435
- int64_t i1 = std::min(i0 + qbs2, n);
436
- search_implem_14<C>(
437
- i1 - i0,
438
- x + d * i0,
439
- k,
440
- distances + i0 * k,
441
- labels + i0 * k,
442
- impl);
443
- }
444
- return;
445
- }
446
-
447
- size_t dim12 = pq.ksub * M2;
448
- AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
449
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
450
-
451
- if (skip & 1) {
452
- quantized_dis_tables.clear();
453
- } else {
454
- compute_quantized_LUT(
455
- n, x, quantized_dis_tables.get(), normalizers.get());
456
- }
457
-
458
- AlignedTable<uint8_t> LUT(n * dim12);
459
- pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
460
-
461
- if (k == 1) {
462
- SingleResultHandler<C> handler(n, ntotal);
463
- if (skip & 4) {
464
- // pass
465
- } else {
466
- handler.disable = bool(skip & 2);
467
- pq4_accumulate_loop(
468
- n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
469
- }
470
- handler.to_flat_arrays(distances, labels, normalizers.get());
471
-
472
- } else if (impl == 14) {
473
- std::vector<uint16_t> tmp_dis(n * k);
474
- std::vector<int32_t> tmp_ids(n * k);
475
-
476
- if (skip & 4) {
477
- // skip
478
- } else if (k > 1) {
479
- HeapHandler<C> handler(
480
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
481
- handler.disable = bool(skip & 2);
482
-
483
- pq4_accumulate_loop(
484
- n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
485
-
486
- if (!(skip & 8)) {
487
- handler.to_flat_arrays(distances, labels, normalizers.get());
488
- }
489
- }
490
-
491
- } else { // impl == 15
492
-
493
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
494
- handler.disable = bool(skip & 2);
495
-
496
- if (skip & 4) {
497
- // skip
498
- } else {
499
- pq4_accumulate_loop(
500
- n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
501
- }
502
-
503
- if (!(skip & 8)) {
504
- handler.to_flat_arrays(distances, labels, normalizers.get());
505
- }
506
- }
70
+ void IndexPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
71
+ pq.decode(bytes, x, n);
507
72
  }
508
73
 
509
74
  } // namespace faiss
@@ -7,6 +7,7 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <faiss/IndexFastScan.h>
10
11
  #include <faiss/IndexPQ.h>
11
12
  #include <faiss/impl/ProductQuantizer.h>
12
13
  #include <faiss/utils/AlignedTable.h>
@@ -25,27 +26,9 @@ namespace faiss {
25
26
  * 15: no qbs with reservoir accumulator
26
27
  */
27
28
 
28
- struct IndexPQFastScan : Index {
29
+ struct IndexPQFastScan : IndexFastScan {
29
30
  ProductQuantizer pq;
30
31
 
31
- // implementation to select
32
- int implem = 0;
33
- // skip some parts of the computation (for timing)
34
- int skip = 0;
35
-
36
- // size of the kernel
37
- int bbs; // set at build time
38
- int qbs = 0; // query block size 0 = use default
39
-
40
- // packed version of the codes
41
- size_t ntotal2;
42
- size_t M2;
43
-
44
- AlignedTable<uint8_t> codes;
45
-
46
- // this is for testing purposes only (set when initialized by IndexPQ)
47
- const uint8_t* orig_codes = nullptr;
48
-
49
32
  IndexPQFastScan(
50
33
  int d,
51
34
  size_t M,
@@ -53,73 +36,27 @@ struct IndexPQFastScan : Index {
53
36
  MetricType metric = METRIC_L2,
54
37
  int bbs = 32);
55
38
 
56
- IndexPQFastScan();
39
+ IndexPQFastScan() = default;
57
40
 
58
41
  /// build from an existing IndexPQ
59
42
  explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
60
43
 
61
44
  void train(idx_t n, const float* x) override;
62
- void add(idx_t n, const float* x) override;
63
- void reset() override;
64
- void search(
65
- idx_t n,
66
- const float* x,
67
- idx_t k,
68
- float* distances,
69
- idx_t* labels) const override;
70
45
 
71
- // called by search function
72
- void compute_quantized_LUT(
73
- idx_t n,
74
- const float* x,
75
- uint8_t* lut,
76
- float* normalizers) const;
46
+ void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
77
47
 
78
- template <bool is_max>
79
- void search_dispatch_implem(
80
- idx_t n,
81
- const float* x,
82
- idx_t k,
83
- float* distances,
84
- idx_t* labels) const;
48
+ void compute_float_LUT(float* lut, idx_t n, const float* x) const override;
85
49
 
86
- template <class C>
87
- void search_implem_2(
88
- idx_t n,
89
- const float* x,
90
- idx_t k,
91
- float* distances,
92
- idx_t* labels) const;
93
-
94
- template <class C>
95
- void search_implem_12(
96
- idx_t n,
97
- const float* x,
98
- idx_t k,
99
- float* distances,
100
- idx_t* labels,
101
- int impl) const;
102
-
103
- template <class C>
104
- void search_implem_14(
105
- idx_t n,
106
- const float* x,
107
- idx_t k,
108
- float* distances,
109
- idx_t* labels,
110
- int impl) const;
50
+ /** Decode a set of vectors.
51
+ *
52
+ * NOTE: The codes in the IndexPQFastScan object are non-contiguous.
53
+ * But this method requires a contiguous representation.
54
+ *
55
+ * @param n number of vectors
56
+ * @param bytes input encoded vectors, size n * code_size
57
+ * @param x output vectors, size n * d
58
+ */
59
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
111
60
  };
112
61
 
113
- struct FastScanStats {
114
- uint64_t t0, t1, t2, t3;
115
- FastScanStats() {
116
- reset();
117
- }
118
- void reset() {
119
- memset(this, 0, sizeof(*this));
120
- }
121
- };
122
-
123
- FAISS_API extern FastScanStats FastScan_stats;
124
-
125
62
  } // namespace faiss