faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -14,17 +14,11 @@
14
14
  #include <omp.h>
15
15
 
16
16
  #include <faiss/impl/FaissAssert.h>
17
- #include <faiss/utils/random.h>
18
- #include <faiss/utils/utils.h>
19
-
20
17
  #include <faiss/impl/pq4_fast_scan.h>
21
- #include <faiss/impl/simd_result_handlers.h>
22
- #include <faiss/utils/quantize_lut.h>
18
+ #include <faiss/utils/utils.h>
23
19
 
24
20
  namespace faiss {
25
21
 
26
- using namespace simd_result_handlers;
27
-
28
22
  inline size_t roundup(size_t a, size_t b) {
29
23
  return (a + b - 1) / b * b;
30
24
  }
@@ -35,37 +29,19 @@ IndexPQFastScan::IndexPQFastScan(
35
29
  size_t nbits,
36
30
  MetricType metric,
37
31
  int bbs)
38
- : Index(d, metric),
39
- pq(d, M, nbits),
40
- bbs(bbs),
41
- ntotal2(0),
42
- M2(roundup(M, 2)) {
43
- FAISS_THROW_IF_NOT(nbits == 4);
44
- is_trained = false;
32
+ : pq(d, M, nbits) {
33
+ init_fastscan(d, M, nbits, metric, bbs);
45
34
  }
46
35
 
47
- IndexPQFastScan::IndexPQFastScan() : bbs(0), ntotal2(0), M2(0) {}
48
-
49
- IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs)
50
- : Index(orig.d, orig.metric_type), pq(orig.pq), bbs(bbs) {
51
- FAISS_THROW_IF_NOT(orig.pq.nbits == 4);
36
+ IndexPQFastScan::IndexPQFastScan(const IndexPQ& orig, int bbs) : pq(orig.pq) {
37
+ init_fastscan(orig.d, pq.M, pq.nbits, orig.metric_type, bbs);
52
38
  ntotal = orig.ntotal;
39
+ ntotal2 = roundup(ntotal, bbs);
53
40
  is_trained = orig.is_trained;
54
41
  orig_codes = orig.codes.data();
55
42
 
56
- qbs = 0; // means use default
57
-
58
43
  // pack the codes
59
-
60
- size_t M = pq.M;
61
-
62
- FAISS_THROW_IF_NOT(bbs % 32 == 0);
63
- M2 = roundup(M, 2);
64
- ntotal2 = roundup(ntotal, bbs);
65
-
66
44
  codes.resize(ntotal2 * M2 / 2);
67
-
68
- // printf("M=%d M2=%d code_size=%d\n", M, M2, pq.code_size);
69
45
  pq4_pack_codes(orig.codes.data(), ntotal, M, ntotal2, bbs, M2, codes.get());
70
46
  }
71
47
 
@@ -77,433 +53,22 @@ void IndexPQFastScan::train(idx_t n, const float* x) {
77
53
  is_trained = true;
78
54
  }
79
55
 
80
- void IndexPQFastScan::add(idx_t n, const float* x) {
81
- FAISS_THROW_IF_NOT(is_trained);
82
- AlignedTable<uint8_t> tmp_codes(n * pq.code_size);
83
- pq.compute_codes(x, tmp_codes.get(), n);
84
- ntotal2 = roundup(ntotal + n, bbs);
85
- size_t new_size = ntotal2 * M2 / 2;
86
- size_t old_size = codes.size();
87
- if (new_size > old_size) {
88
- codes.resize(new_size);
89
- memset(codes.get() + old_size, 0, new_size - old_size);
90
- }
91
- pq4_pack_codes_range(
92
- tmp_codes.get(), pq.M, ntotal, ntotal + n, bbs, M2, codes.get());
93
- ntotal += n;
94
- }
95
-
96
- void IndexPQFastScan::reset() {
97
- codes.resize(0);
98
- ntotal = 0;
99
- }
100
-
101
- namespace {
102
-
103
- // from impl/ProductQuantizer.cpp
104
- template <class C, typename dis_t>
105
- void pq_estimators_from_tables_generic(
106
- const ProductQuantizer& pq,
107
- size_t nbits,
108
- const uint8_t* codes,
109
- size_t ncodes,
110
- const dis_t* dis_table,
111
- size_t k,
112
- typename C::T* heap_dis,
113
- int64_t* heap_ids) {
114
- using accu_t = typename C::T;
115
- const size_t M = pq.M;
116
- const size_t ksub = pq.ksub;
117
- for (size_t j = 0; j < ncodes; ++j) {
118
- PQDecoderGeneric decoder(codes + j * pq.code_size, nbits);
119
- accu_t dis = 0;
120
- const dis_t* __restrict dt = dis_table;
121
- for (size_t m = 0; m < M; m++) {
122
- uint64_t c = decoder.decode();
123
- dis += dt[c];
124
- dt += ksub;
125
- }
126
-
127
- if (C::cmp(heap_dis[0], dis)) {
128
- heap_pop<C>(k, heap_dis, heap_ids);
129
- heap_push<C>(k, heap_dis, heap_ids, dis, j);
130
- }
131
- }
132
- }
133
-
134
- } // anonymous namespace
135
-
136
- using namespace quantize_lut;
137
-
138
- void IndexPQFastScan::compute_quantized_LUT(
139
- idx_t n,
140
- const float* x,
141
- uint8_t* lut,
142
- float* normalizers) const {
143
- size_t dim12 = pq.ksub * pq.M;
144
- std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
145
- if (metric_type == METRIC_L2) {
146
- pq.compute_distance_tables(n, x, dis_tables.get());
147
- } else {
148
- pq.compute_inner_prod_tables(n, x, dis_tables.get());
149
- }
150
-
151
- for (uint64_t i = 0; i < n; i++) {
152
- round_uint8_per_column(
153
- dis_tables.get() + i * dim12,
154
- pq.M,
155
- pq.ksub,
156
- &normalizers[2 * i],
157
- &normalizers[2 * i + 1]);
158
- }
159
-
160
- for (uint64_t i = 0; i < n; i++) {
161
- const float* t_in = dis_tables.get() + i * dim12;
162
- uint8_t* t_out = lut + i * M2 * pq.ksub;
163
-
164
- for (int j = 0; j < dim12; j++) {
165
- t_out[j] = int(t_in[j]);
166
- }
167
- memset(t_out + dim12, 0, (M2 - pq.M) * pq.ksub);
168
- }
56
+ void IndexPQFastScan::compute_codes(uint8_t* codes, idx_t n, const float* x)
57
+ const {
58
+ pq.compute_codes(x, codes, n);
169
59
  }
170
60
 
171
- /******************************************************************************
172
- * Search driver routine
173
- ******************************************************************************/
174
-
175
- void IndexPQFastScan::search(
176
- idx_t n,
177
- const float* x,
178
- idx_t k,
179
- float* distances,
180
- idx_t* labels) const {
181
- FAISS_THROW_IF_NOT(k > 0);
182
-
61
+ void IndexPQFastScan::compute_float_LUT(float* lut, idx_t n, const float* x)
62
+ const {
183
63
  if (metric_type == METRIC_L2) {
184
- search_dispatch_implem<true>(n, x, k, distances, labels);
64
+ pq.compute_distance_tables(n, x, lut);
185
65
  } else {
186
- search_dispatch_implem<false>(n, x, k, distances, labels);
66
+ pq.compute_inner_prod_tables(n, x, lut);
187
67
  }
188
68
  }
189
69
 
190
- template <bool is_max>
191
- void IndexPQFastScan::search_dispatch_implem(
192
- idx_t n,
193
- const float* x,
194
- idx_t k,
195
- float* distances,
196
- idx_t* labels) const {
197
- using Cfloat = typename std::conditional<
198
- is_max,
199
- CMax<float, int64_t>,
200
- CMin<float, int64_t>>::type;
201
-
202
- using C = typename std::
203
- conditional<is_max, CMax<uint16_t, int>, CMin<uint16_t, int>>::type;
204
-
205
- if (n == 0) {
206
- return;
207
- }
208
-
209
- // actual implementation used
210
- int impl = implem;
211
-
212
- if (impl == 0) {
213
- if (bbs == 32) {
214
- impl = 12;
215
- } else {
216
- impl = 14;
217
- }
218
- if (k > 20) {
219
- impl++;
220
- }
221
- }
222
-
223
- if (implem == 1) {
224
- FAISS_THROW_IF_NOT(orig_codes);
225
- FAISS_THROW_IF_NOT(is_max);
226
- float_maxheap_array_t res = {size_t(n), size_t(k), labels, distances};
227
- pq.search(x, n, orig_codes, ntotal, &res, true);
228
- } else if (implem == 2 || implem == 3 || implem == 4) {
229
- FAISS_THROW_IF_NOT(orig_codes);
230
-
231
- size_t dim12 = pq.ksub * pq.M;
232
- std::unique_ptr<float[]> dis_tables(new float[n * dim12]);
233
- if (is_max) {
234
- pq.compute_distance_tables(n, x, dis_tables.get());
235
- } else {
236
- pq.compute_inner_prod_tables(n, x, dis_tables.get());
237
- }
238
-
239
- std::vector<float> normalizers(n * 2);
240
-
241
- if (implem == 2) {
242
- // default float
243
- } else if (implem == 3 || implem == 4) {
244
- for (uint64_t i = 0; i < n; i++) {
245
- round_uint8_per_column(
246
- dis_tables.get() + i * dim12,
247
- pq.M,
248
- pq.ksub,
249
- &normalizers[2 * i],
250
- &normalizers[2 * i + 1]);
251
- }
252
- }
253
-
254
- for (int64_t i = 0; i < n; i++) {
255
- int64_t* heap_ids = labels + i * k;
256
- float* heap_dis = distances + i * k;
257
-
258
- heap_heapify<Cfloat>(k, heap_dis, heap_ids);
259
-
260
- pq_estimators_from_tables_generic<Cfloat>(
261
- pq,
262
- pq.nbits,
263
- orig_codes,
264
- ntotal,
265
- dis_tables.get() + i * dim12,
266
- k,
267
- heap_dis,
268
- heap_ids);
269
-
270
- heap_reorder<Cfloat>(k, heap_dis, heap_ids);
271
-
272
- if (implem == 4) {
273
- float a = normalizers[2 * i];
274
- float b = normalizers[2 * i + 1];
275
-
276
- for (int j = 0; j < k; j++) {
277
- heap_dis[j] = heap_dis[j] / a + b;
278
- }
279
- }
280
- }
281
- } else if (impl >= 12 && impl <= 15) {
282
- FAISS_THROW_IF_NOT(ntotal < INT_MAX);
283
- int nt = std::min(omp_get_max_threads(), int(n));
284
- if (nt < 2) {
285
- if (impl == 12 || impl == 13) {
286
- search_implem_12<C>(n, x, k, distances, labels, impl);
287
- } else {
288
- search_implem_14<C>(n, x, k, distances, labels, impl);
289
- }
290
- } else {
291
- // explicitly slice over threads
292
- #pragma omp parallel for num_threads(nt)
293
- for (int slice = 0; slice < nt; slice++) {
294
- idx_t i0 = n * slice / nt;
295
- idx_t i1 = n * (slice + 1) / nt;
296
- float* dis_i = distances + i0 * k;
297
- idx_t* lab_i = labels + i0 * k;
298
- if (impl == 12 || impl == 13) {
299
- search_implem_12<C>(
300
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
301
- } else {
302
- search_implem_14<C>(
303
- i1 - i0, x + i0 * d, k, dis_i, lab_i, impl);
304
- }
305
- }
306
- }
307
- } else {
308
- FAISS_THROW_FMT("invalid implem %d impl=%d", implem, impl);
309
- }
310
- }
311
-
312
- template <class C>
313
- void IndexPQFastScan::search_implem_12(
314
- idx_t n,
315
- const float* x,
316
- idx_t k,
317
- float* distances,
318
- idx_t* labels,
319
- int impl) const {
320
- FAISS_THROW_IF_NOT(bbs == 32);
321
-
322
- // handle qbs2 blocking by recursive call
323
- int64_t qbs2 = this->qbs == 0 ? 11 : pq4_qbs_to_nq(this->qbs);
324
- if (n > qbs2) {
325
- for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
326
- int64_t i1 = std::min(i0 + qbs2, n);
327
- search_implem_12<C>(
328
- i1 - i0,
329
- x + d * i0,
330
- k,
331
- distances + i0 * k,
332
- labels + i0 * k,
333
- impl);
334
- }
335
- return;
336
- }
337
-
338
- size_t dim12 = pq.ksub * M2;
339
- AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
340
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
341
-
342
- if (skip & 1) {
343
- quantized_dis_tables.clear();
344
- } else {
345
- compute_quantized_LUT(
346
- n, x, quantized_dis_tables.get(), normalizers.get());
347
- }
348
-
349
- AlignedTable<uint8_t> LUT(n * dim12);
350
-
351
- // block sizes are encoded in qbs, 4 bits at a time
352
-
353
- // caution: we override an object field
354
- int qbs = this->qbs;
355
-
356
- if (n != pq4_qbs_to_nq(qbs)) {
357
- qbs = pq4_preferred_qbs(n);
358
- }
359
-
360
- int LUT_nq =
361
- pq4_pack_LUT_qbs(qbs, M2, quantized_dis_tables.get(), LUT.get());
362
- FAISS_THROW_IF_NOT(LUT_nq == n);
363
-
364
- if (k == 1) {
365
- SingleResultHandler<C> handler(n, ntotal);
366
- if (skip & 4) {
367
- // pass
368
- } else {
369
- handler.disable = bool(skip & 2);
370
- pq4_accumulate_loop_qbs(
371
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
372
- }
373
-
374
- handler.to_flat_arrays(distances, labels, normalizers.get());
375
-
376
- } else if (impl == 12) {
377
- std::vector<uint16_t> tmp_dis(n * k);
378
- std::vector<int32_t> tmp_ids(n * k);
379
-
380
- if (skip & 4) {
381
- // skip
382
- } else {
383
- HeapHandler<C> handler(
384
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
385
- handler.disable = bool(skip & 2);
386
-
387
- pq4_accumulate_loop_qbs(
388
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
389
-
390
- if (!(skip & 8)) {
391
- handler.to_flat_arrays(distances, labels, normalizers.get());
392
- }
393
- }
394
-
395
- } else { // impl == 13
396
-
397
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
398
- handler.disable = bool(skip & 2);
399
-
400
- if (skip & 4) {
401
- // skip
402
- } else {
403
- pq4_accumulate_loop_qbs(
404
- qbs, ntotal2, M2, codes.get(), LUT.get(), handler);
405
- }
406
-
407
- if (!(skip & 8)) {
408
- handler.to_flat_arrays(distances, labels, normalizers.get());
409
- }
410
-
411
- FastScan_stats.t0 += handler.times[0];
412
- FastScan_stats.t1 += handler.times[1];
413
- FastScan_stats.t2 += handler.times[2];
414
- FastScan_stats.t3 += handler.times[3];
415
- }
416
- }
417
-
418
- FastScanStats FastScan_stats;
419
-
420
- template <class C>
421
- void IndexPQFastScan::search_implem_14(
422
- idx_t n,
423
- const float* x,
424
- idx_t k,
425
- float* distances,
426
- idx_t* labels,
427
- int impl) const {
428
- FAISS_THROW_IF_NOT(bbs % 32 == 0);
429
-
430
- int qbs2 = qbs == 0 ? 4 : qbs;
431
-
432
- // handle qbs2 blocking by recursive call
433
- if (n > qbs2) {
434
- for (int64_t i0 = 0; i0 < n; i0 += qbs2) {
435
- int64_t i1 = std::min(i0 + qbs2, n);
436
- search_implem_14<C>(
437
- i1 - i0,
438
- x + d * i0,
439
- k,
440
- distances + i0 * k,
441
- labels + i0 * k,
442
- impl);
443
- }
444
- return;
445
- }
446
-
447
- size_t dim12 = pq.ksub * M2;
448
- AlignedTable<uint8_t> quantized_dis_tables(n * dim12);
449
- std::unique_ptr<float[]> normalizers(new float[2 * n]);
450
-
451
- if (skip & 1) {
452
- quantized_dis_tables.clear();
453
- } else {
454
- compute_quantized_LUT(
455
- n, x, quantized_dis_tables.get(), normalizers.get());
456
- }
457
-
458
- AlignedTable<uint8_t> LUT(n * dim12);
459
- pq4_pack_LUT(n, M2, quantized_dis_tables.get(), LUT.get());
460
-
461
- if (k == 1) {
462
- SingleResultHandler<C> handler(n, ntotal);
463
- if (skip & 4) {
464
- // pass
465
- } else {
466
- handler.disable = bool(skip & 2);
467
- pq4_accumulate_loop(
468
- n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
469
- }
470
- handler.to_flat_arrays(distances, labels, normalizers.get());
471
-
472
- } else if (impl == 14) {
473
- std::vector<uint16_t> tmp_dis(n * k);
474
- std::vector<int32_t> tmp_ids(n * k);
475
-
476
- if (skip & 4) {
477
- // skip
478
- } else if (k > 1) {
479
- HeapHandler<C> handler(
480
- n, tmp_dis.data(), tmp_ids.data(), k, ntotal);
481
- handler.disable = bool(skip & 2);
482
-
483
- pq4_accumulate_loop(
484
- n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
485
-
486
- if (!(skip & 8)) {
487
- handler.to_flat_arrays(distances, labels, normalizers.get());
488
- }
489
- }
490
-
491
- } else { // impl == 15
492
-
493
- ReservoirHandler<C> handler(n, ntotal, k, 2 * k);
494
- handler.disable = bool(skip & 2);
495
-
496
- if (skip & 4) {
497
- // skip
498
- } else {
499
- pq4_accumulate_loop(
500
- n, ntotal2, bbs, M2, codes.get(), LUT.get(), handler);
501
- }
502
-
503
- if (!(skip & 8)) {
504
- handler.to_flat_arrays(distances, labels, normalizers.get());
505
- }
506
- }
70
+ void IndexPQFastScan::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
71
+ pq.decode(bytes, x, n);
507
72
  }
508
73
 
509
74
  } // namespace faiss
@@ -7,6 +7,7 @@
7
7
 
8
8
  #pragma once
9
9
 
10
+ #include <faiss/IndexFastScan.h>
10
11
  #include <faiss/IndexPQ.h>
11
12
  #include <faiss/impl/ProductQuantizer.h>
12
13
  #include <faiss/utils/AlignedTable.h>
@@ -25,27 +26,9 @@ namespace faiss {
25
26
  * 15: no qbs with reservoir accumulator
26
27
  */
27
28
 
28
- struct IndexPQFastScan : Index {
29
+ struct IndexPQFastScan : IndexFastScan {
29
30
  ProductQuantizer pq;
30
31
 
31
- // implementation to select
32
- int implem = 0;
33
- // skip some parts of the computation (for timing)
34
- int skip = 0;
35
-
36
- // size of the kernel
37
- int bbs; // set at build time
38
- int qbs = 0; // query block size 0 = use default
39
-
40
- // packed version of the codes
41
- size_t ntotal2;
42
- size_t M2;
43
-
44
- AlignedTable<uint8_t> codes;
45
-
46
- // this is for testing purposes only (set when initialized by IndexPQ)
47
- const uint8_t* orig_codes = nullptr;
48
-
49
32
  IndexPQFastScan(
50
33
  int d,
51
34
  size_t M,
@@ -53,73 +36,27 @@ struct IndexPQFastScan : Index {
53
36
  MetricType metric = METRIC_L2,
54
37
  int bbs = 32);
55
38
 
56
- IndexPQFastScan();
39
+ IndexPQFastScan() = default;
57
40
 
58
41
  /// build from an existing IndexPQ
59
42
  explicit IndexPQFastScan(const IndexPQ& orig, int bbs = 32);
60
43
 
61
44
  void train(idx_t n, const float* x) override;
62
- void add(idx_t n, const float* x) override;
63
- void reset() override;
64
- void search(
65
- idx_t n,
66
- const float* x,
67
- idx_t k,
68
- float* distances,
69
- idx_t* labels) const override;
70
45
 
71
- // called by search function
72
- void compute_quantized_LUT(
73
- idx_t n,
74
- const float* x,
75
- uint8_t* lut,
76
- float* normalizers) const;
46
+ void compute_codes(uint8_t* codes, idx_t n, const float* x) const override;
77
47
 
78
- template <bool is_max>
79
- void search_dispatch_implem(
80
- idx_t n,
81
- const float* x,
82
- idx_t k,
83
- float* distances,
84
- idx_t* labels) const;
48
+ void compute_float_LUT(float* lut, idx_t n, const float* x) const override;
85
49
 
86
- template <class C>
87
- void search_implem_2(
88
- idx_t n,
89
- const float* x,
90
- idx_t k,
91
- float* distances,
92
- idx_t* labels) const;
93
-
94
- template <class C>
95
- void search_implem_12(
96
- idx_t n,
97
- const float* x,
98
- idx_t k,
99
- float* distances,
100
- idx_t* labels,
101
- int impl) const;
102
-
103
- template <class C>
104
- void search_implem_14(
105
- idx_t n,
106
- const float* x,
107
- idx_t k,
108
- float* distances,
109
- idx_t* labels,
110
- int impl) const;
50
+ /** Decode a set of vectors.
51
+ *
52
+ * NOTE: The codes in the IndexPQFastScan object are non-contiguous.
53
+ * But this method requires a contiguous representation.
54
+ *
55
+ * @param n number of vectors
56
+ * @param bytes input encoded vectors, size n * code_size
57
+ * @param x output vectors, size n * d
58
+ */
59
+ void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
111
60
  };
112
61
 
113
- struct FastScanStats {
114
- uint64_t t0, t1, t2, t3;
115
- FastScanStats() {
116
- reset();
117
- }
118
- void reset() {
119
- memset(this, 0, sizeof(*this));
120
- }
121
- };
122
-
123
- FAISS_API extern FastScanStats FastScan_stats;
124
-
125
62
  } // namespace faiss