faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -40,6 +40,90 @@ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
40
40
 
41
41
  namespace {
42
42
 
43
+ /************************************************************
44
+ * DistanceComputer implementation
45
+ ************************************************************/
46
+
47
+ template <class VectorDistance>
48
+ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
49
+ std::vector<float> tmp;
50
+ const AdditiveQuantizer & aq;
51
+ VectorDistance vd;
52
+ size_t d;
53
+
54
+ AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
55
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
56
+ tmp(iaq.d * 2),
57
+ aq(*iaq.aq),
58
+ vd(vd),
59
+ d(iaq.d)
60
+ {}
61
+
62
+ const float *q;
63
+ void set_query(const float* x) final {
64
+ q = x;
65
+ }
66
+
67
+ float symmetric_dis(idx_t i, idx_t j) final {
68
+ aq.decode(codes + i * d, tmp.data(), 1);
69
+ aq.decode(codes + j * d, tmp.data() + d, 1);
70
+ return vd(tmp.data(), tmp.data() + d);
71
+ }
72
+
73
+ float distance_to_code(const uint8_t *code) final {
74
+ aq.decode(code, tmp.data(), 1);
75
+ return vd(q, tmp.data());
76
+ }
77
+
78
+ virtual ~AQDistanceComputerDecompress() {}
79
+ };
80
+
81
+
82
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st>
83
+ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
84
+ std::vector<float> LUT;
85
+ const AdditiveQuantizer & aq;
86
+ size_t d;
87
+
88
+ explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
89
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
90
+ LUT(iaq.aq->total_codebook_size + iaq.d * 2),
91
+ aq(*iaq.aq),
92
+ d(iaq.d)
93
+ {}
94
+
95
+ float bias;
96
+ void set_query(const float* x) final {
97
+ // this is quite sub-optimal for multiple queries
98
+ aq.compute_LUT(1, x, LUT.data());
99
+ if (is_IP) {
100
+ bias = 0;
101
+ } else {
102
+ bias = fvec_norm_L2sqr(x, d);
103
+ }
104
+ }
105
+
106
+ float symmetric_dis(idx_t i, idx_t j) final {
107
+ float *tmp = LUT.data();
108
+ aq.decode(codes + i * d, tmp, 1);
109
+ aq.decode(codes + j * d, tmp + d, 1);
110
+ return fvec_L2sqr(tmp, tmp + d, d);
111
+ }
112
+
113
+ float distance_to_code(const uint8_t *code) final {
114
+ return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
115
+ }
116
+
117
+ virtual ~AQDistanceComputerLUT() {}
118
+ };
119
+
120
+
121
+
122
+ /************************************************************
123
+ * scanning implementation for search
124
+ ************************************************************/
125
+
126
+
43
127
  template <class VectorDistance, class ResultHandler>
44
128
  void search_with_decompress(
45
129
  const IndexAdditiveQuantizer& ir,
@@ -111,12 +195,61 @@ void search_with_LUT(
111
195
 
112
196
  } // anonymous namespace
113
197
 
198
+
199
+ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
200
+
201
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
202
+ if (metric_type == METRIC_L2) {
203
+ using VD = VectorDistance<METRIC_L2>;
204
+ VD vd = {size_t(d), metric_arg};
205
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
206
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
207
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
208
+ VD vd = {size_t(d), metric_arg};
209
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
210
+ } else {
211
+ FAISS_THROW_MSG("unsupported metric");
212
+ }
213
+ } else {
214
+ if (metric_type == METRIC_INNER_PRODUCT) {
215
+ return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
216
+ } else {
217
+ switch(aq->search_type) {
218
+ #define DISPATCH(st) \
219
+ case AdditiveQuantizer::st: \
220
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
221
+ break;
222
+ DISPATCH(ST_norm_float)
223
+ DISPATCH(ST_LUT_nonorm)
224
+ DISPATCH(ST_norm_qint8)
225
+ DISPATCH(ST_norm_qint4)
226
+ DISPATCH(ST_norm_cqint4)
227
+ case AdditiveQuantizer::ST_norm_cqint8:
228
+ case AdditiveQuantizer::ST_norm_lsq2x4:
229
+ case AdditiveQuantizer::ST_norm_rq2x4:
230
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
231
+ break;
232
+ #undef DISPATCH
233
+ default:
234
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
235
+ }
236
+ }
237
+ }
238
+ }
239
+
240
+
241
+
242
+
114
243
  void IndexAdditiveQuantizer::search(
115
244
  idx_t n,
116
245
  const float* x,
117
246
  idx_t k,
118
247
  float* distances,
119
- idx_t* labels) const {
248
+ idx_t* labels,
249
+ const SearchParameters* params) const {
250
+
251
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
252
+
120
253
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
121
254
  if (metric_type == METRIC_L2) {
122
255
  using VD = VectorDistance<METRIC_L2>;
@@ -135,20 +268,23 @@ void IndexAdditiveQuantizer::search(
135
268
  search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
136
269
  } else {
137
270
  HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
138
-
139
- if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
140
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
141
- } else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
142
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
143
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
144
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
145
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
146
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
147
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8) {
271
+ switch(aq->search_type) {
272
+ #define DISPATCH(st) \
273
+ case AdditiveQuantizer::st: \
274
+ search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
275
+ break;
276
+ DISPATCH(ST_norm_float)
277
+ DISPATCH(ST_LUT_nonorm)
278
+ DISPATCH(ST_norm_qint8)
279
+ DISPATCH(ST_norm_qint4)
280
+ DISPATCH(ST_norm_cqint4)
281
+ case AdditiveQuantizer::ST_norm_cqint8:
282
+ case AdditiveQuantizer::ST_norm_lsq2x4:
283
+ case AdditiveQuantizer::ST_norm_rq2x4:
148
284
  search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
149
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
150
- search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
151
- } else {
285
+ break;
286
+ #undef DISPATCH
287
+ default:
152
288
  FAISS_THROW_FMT("search type %d not supported", aq->search_type);
153
289
  }
154
290
  }
@@ -220,6 +356,57 @@ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
220
356
  is_trained = true;
221
357
  }
222
358
 
359
+
360
+ /**************************************************************************************
361
+ * IndexProductResidualQuantizer
362
+ **************************************************************************************/
363
+
364
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer(
365
+ int d, ///< dimensionality of the input vectors
366
+ size_t nsplits, ///< number of residual quantizers
367
+ size_t Msub, ///< number of subquantizers per RQ
368
+ size_t nbits, ///< number of bit per subvector index
369
+ MetricType metric,
370
+ Search_type_t search_type)
371
+ : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
372
+ code_size = prq.code_size;
373
+ is_trained = false;
374
+ }
375
+
376
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer()
377
+ : IndexProductResidualQuantizer(0, 0, 0, 0) {}
378
+
379
+ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
380
+ prq.train(n, x);
381
+ is_trained = true;
382
+ }
383
+
384
+
385
+ /**************************************************************************************
386
+ * IndexProductLocalSearchQuantizer
387
+ **************************************************************************************/
388
+
389
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t nsplits, ///< number of local search quantizers
392
+ size_t Msub, ///< number of subquantizers per LSQ
393
+ size_t nbits, ///< number of bit per subvector index
394
+ MetricType metric,
395
+ Search_type_t search_type)
396
+ : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
397
+ code_size = plsq.code_size;
398
+ is_trained = false;
399
+ }
400
+
401
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
402
+ : IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
403
+
404
+ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
405
+ plsq.train(n, x);
406
+ is_trained = true;
407
+ }
408
+
409
+
223
410
  /**************************************************************************************
224
411
  * AdditiveCoarseQuantizer
225
412
  **************************************************************************************/
@@ -248,6 +435,13 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
248
435
  if (verbose) {
249
436
  printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
250
437
  }
438
+ size_t norms_size = sizeof(float) << aq->tot_bits;
439
+
440
+ FAISS_THROW_IF_NOT_MSG (
441
+ norms_size <= aq->max_mem_distances,
442
+ "the RCQ norms matrix will become too large, please reduce the number of quantization steps"
443
+ );
444
+
251
445
  aq->train(n, x);
252
446
  is_trained = true;
253
447
  ntotal = (idx_t)1 << aq->tot_bits;
@@ -268,7 +462,11 @@ void AdditiveCoarseQuantizer::search(
268
462
  const float* x,
269
463
  idx_t k,
270
464
  float* distances,
271
- idx_t* labels) const {
465
+ idx_t* labels,
466
+ const SearchParameters * params) const {
467
+
468
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
469
+
272
470
  if (metric_type == METRIC_INNER_PRODUCT) {
273
471
  aq->knn_centroids_inner_product(n, x, k, distances, labels);
274
472
  } else if (metric_type == METRIC_L2) {
@@ -321,7 +519,12 @@ void ResidualCoarseQuantizer::search(
321
519
  const float* x,
322
520
  idx_t k,
323
521
  float* distances,
324
- idx_t* labels) const {
522
+ idx_t* labels,
523
+ const SearchParameters * params
524
+ ) const {
525
+
526
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
527
+
325
528
  if (beam_factor < 0) {
326
529
  AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
327
530
  return;
@@ -15,6 +15,7 @@
15
15
 
16
16
  #include <faiss/IndexFlatCodes.h>
17
17
  #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ProductAdditiveQuantizer.h>
18
19
  #include <faiss/impl/ResidualQuantizer.h>
19
20
  #include <faiss/impl/platform_macros.h>
20
21
 
@@ -28,8 +29,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
28
29
  using Search_type_t = AdditiveQuantizer::Search_type_t;
29
30
 
30
31
  explicit IndexAdditiveQuantizer(
31
- idx_t d = 0,
32
- AdditiveQuantizer* aq = nullptr,
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
33
34
  MetricType metric = METRIC_L2);
34
35
 
35
36
  void search(
@@ -37,12 +38,15 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
37
38
  const float* x,
38
39
  idx_t k,
39
40
  float* distances,
40
- idx_t* labels) const override;
41
+ idx_t* labels,
42
+ const SearchParameters* params = nullptr) const override;
41
43
 
42
44
  /* The standalone codec interface */
43
45
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
44
46
 
45
47
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
48
+
49
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
46
50
  };
47
51
 
48
52
  /** Index based on a residual quantizer. Stored vectors are
@@ -98,6 +102,58 @@ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
98
102
  void train(idx_t n, const float* x) override;
99
103
  };
100
104
 
105
+ /** Index based on a product residual quantizer.
106
+ */
107
+ struct IndexProductResidualQuantizer : IndexAdditiveQuantizer {
108
+ /// The product residual quantizer used to encode the vectors
109
+ ProductResidualQuantizer prq;
110
+
111
+ /** Constructor.
112
+ *
113
+ * @param d dimensionality of the input vectors
114
+ * @param nsplits number of residual quantizers
115
+ * @param Msub number of subquantizers per RQ
116
+ * @param nbits number of bit per subvector index
117
+ */
118
+ IndexProductResidualQuantizer(
119
+ int d, ///< dimensionality of the input vectors
120
+ size_t nsplits, ///< number of residual quantizers
121
+ size_t Msub, ///< number of subquantizers per RQ
122
+ size_t nbits, ///< number of bit per subvector index
123
+ MetricType metric = METRIC_L2,
124
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
125
+
126
+ IndexProductResidualQuantizer();
127
+
128
+ void train(idx_t n, const float* x) override;
129
+ };
130
+
131
+ /** Index based on a product local search quantizer.
132
+ */
133
+ struct IndexProductLocalSearchQuantizer : IndexAdditiveQuantizer {
134
+ /// The product local search quantizer used to encode the vectors
135
+ ProductLocalSearchQuantizer plsq;
136
+
137
+ /** Constructor.
138
+ *
139
+ * @param d dimensionality of the input vectors
140
+ * @param nsplits number of local search quantizers
141
+ * @param Msub number of subquantizers per LSQ
142
+ * @param nbits number of bit per subvector index
143
+ */
144
+ IndexProductLocalSearchQuantizer(
145
+ int d, ///< dimensionality of the input vectors
146
+ size_t nsplits, ///< number of local search quantizers
147
+ size_t Msub, ///< number of subquantizers per LSQ
148
+ size_t nbits, ///< number of bit per subvector index
149
+ MetricType metric = METRIC_L2,
150
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
151
+
152
+ IndexProductLocalSearchQuantizer();
153
+
154
+ void train(idx_t n, const float* x) override;
155
+ };
156
+
101
157
  /** A "virtual" index where the elements are the residual quantizer centroids.
102
158
  *
103
159
  * Intended for use as a coarse quantizer in an IndexIVF.
@@ -121,7 +177,8 @@ struct AdditiveCoarseQuantizer : Index {
121
177
  const float* x,
122
178
  idx_t k,
123
179
  float* distances,
124
- idx_t* labels) const override;
180
+ idx_t* labels,
181
+ const SearchParameters* params = nullptr) const override;
125
182
 
126
183
  void reconstruct(idx_t key, float* recons) const override;
127
184
  void train(idx_t n, const float* x) override;
@@ -166,7 +223,8 @@ struct ResidualCoarseQuantizer : AdditiveCoarseQuantizer {
166
223
  const float* x,
167
224
  idx_t k,
168
225
  float* distances,
169
- idx_t* labels) const override;
226
+ idx_t* labels,
227
+ const SearchParameters* params = nullptr) const override;
170
228
 
171
229
  ResidualCoarseQuantizer();
172
230
  };
@@ -0,0 +1,299 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexAdditiveQuantizerFastScan.h>
9
+
10
+ #include <limits.h>
11
+ #include <cassert>
12
+ #include <memory>
13
+
14
+ #include <omp.h>
15
+
16
+ #include <faiss/impl/FaissAssert.h>
17
+ #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/LookupTableScaler.h>
19
+ #include <faiss/impl/ResidualQuantizer.h>
20
+ #include <faiss/impl/pq4_fast_scan.h>
21
+ #include <faiss/utils/quantize_lut.h>
22
+ #include <faiss/utils/utils.h>
23
+
24
+ namespace faiss {
25
+
26
/// Round a up to the nearest multiple of b (b must be non-zero).
inline size_t roundup(size_t a, size_t b) {
    const size_t nblocks = (a + b - 1) / b;
    return nblocks * b;
}
29
+
30
+ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
31
+ AdditiveQuantizer* aq,
32
+ MetricType metric,
33
+ int bbs) {
34
+ init(aq, metric, bbs);
35
+ }
36
+
37
+ void IndexAdditiveQuantizerFastScan::init(
38
+ AdditiveQuantizer* aq,
39
+ MetricType metric,
40
+ int bbs) {
41
+ FAISS_THROW_IF_NOT(aq != nullptr);
42
+ FAISS_THROW_IF_NOT(!aq->nbits.empty());
43
+ FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
44
+ if (metric == METRIC_INNER_PRODUCT) {
45
+ FAISS_THROW_IF_NOT_MSG(
46
+ aq->search_type == AdditiveQuantizer::ST_LUT_nonorm,
47
+ "Search type must be ST_LUT_nonorm for IP metric");
48
+ } else {
49
+ FAISS_THROW_IF_NOT_MSG(
50
+ aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
51
+ aq->search_type == AdditiveQuantizer::ST_norm_rq2x4,
52
+ "Search type must be lsq2x4 or rq2x4 for L2 metric");
53
+ }
54
+
55
+ this->aq = aq;
56
+ if (metric == METRIC_L2) {
57
+ M = aq->M + 2; // 2x4 bits AQ
58
+ } else {
59
+ M = aq->M;
60
+ }
61
+ init_fastscan(aq->d, M, 4, metric, bbs);
62
+
63
+ max_train_points = 1024 * ksub * M;
64
+ }
65
+
66
+ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan()
67
+ : IndexFastScan() {
68
+ is_trained = false;
69
+ aq = nullptr;
70
+ }
71
+
72
+ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
73
+ const IndexAdditiveQuantizer& orig,
74
+ int bbs) {
75
+ init(orig.aq, orig.metric_type, bbs);
76
+
77
+ ntotal = orig.ntotal;
78
+ is_trained = orig.is_trained;
79
+ orig_codes = orig.codes.data();
80
+
81
+ ntotal2 = roundup(ntotal, bbs);
82
+ codes.resize(ntotal2 * M2 / 2);
83
+ pq4_pack_codes(orig_codes, ntotal, M, ntotal2, bbs, M2, codes.get());
84
+ }
85
+
86
+ IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() {}
87
+
88
+ void IndexAdditiveQuantizerFastScan::train(idx_t n, const float* x_in) {
89
+ if (is_trained) {
90
+ return;
91
+ }
92
+
93
+ const int seed = 0x12345;
94
+ size_t nt = n;
95
+ const float* x = fvecs_maybe_subsample(
96
+ d, &nt, max_train_points, x_in, verbose, seed);
97
+ n = nt;
98
+ if (verbose) {
99
+ printf("training additive quantizer on %zd vectors\n", nt);
100
+ }
101
+
102
+ aq->verbose = verbose;
103
+ aq->train(n, x);
104
+ if (metric_type == METRIC_L2) {
105
+ estimate_norm_scale(n, x);
106
+ }
107
+
108
+ is_trained = true;
109
+ }
110
+
111
+ void IndexAdditiveQuantizerFastScan::estimate_norm_scale(
112
+ idx_t n,
113
+ const float* x_in) {
114
+ FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
115
+
116
+ constexpr int seed = 0x980903;
117
+ constexpr size_t max_points_estimated = 65536;
118
+ size_t ns = n;
119
+ const float* x = fvecs_maybe_subsample(
120
+ d, &ns, max_points_estimated, x_in, verbose, seed);
121
+ n = ns;
122
+ std::unique_ptr<float[]> del_x;
123
+ if (x != x_in) {
124
+ del_x.reset((float*)x);
125
+ }
126
+
127
+ std::vector<float> dis_tables(n * M * ksub);
128
+ compute_float_LUT(dis_tables.data(), n, x);
129
+
130
+ // here we compute the mean of scales for each query
131
+ // TODO: try max of scales
132
+ double scale = 0;
133
+
134
+ #pragma omp parallel for reduction(+ : scale)
135
+ for (idx_t i = 0; i < n; i++) {
136
+ const float* lut = dis_tables.data() + i * M * ksub;
137
+ scale += quantize_lut::aq_estimate_norm_scale(M, ksub, 2, lut);
138
+ }
139
+ scale /= n;
140
+ norm_scale = (int)std::roundf(std::max(scale, 1.0));
141
+
142
+ if (verbose) {
143
+ printf("estimated norm scale: %lf\n", scale);
144
+ printf("rounded norm scale: %d\n", norm_scale);
145
+ }
146
+ }
147
+
148
+ void IndexAdditiveQuantizerFastScan::compute_codes(
149
+ uint8_t* tmp_codes,
150
+ idx_t n,
151
+ const float* x) const {
152
+ aq->compute_codes(x, tmp_codes, n);
153
+ }
154
+
155
+ void IndexAdditiveQuantizerFastScan::compute_float_LUT(
156
+ float* lut,
157
+ idx_t n,
158
+ const float* x) const {
159
+ if (metric_type == METRIC_INNER_PRODUCT) {
160
+ aq->compute_LUT(n, x, lut, 1.0f);
161
+ } else {
162
+ // compute inner product look-up tables
163
+ const size_t ip_dim12 = aq->M * ksub;
164
+ const size_t norm_dim12 = 2 * ksub;
165
+ std::vector<float> ip_lut(n * ip_dim12);
166
+ aq->compute_LUT(n, x, ip_lut.data(), -2.0f);
167
+
168
+ // copy and rescale norm look-up tables
169
+ auto norm_tabs = aq->norm_tabs;
170
+ if (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2) {
171
+ for (size_t i = 0; i < norm_tabs.size(); i++) {
172
+ norm_tabs[i] /= norm_scale;
173
+ }
174
+ }
175
+ const float* norm_lut = norm_tabs.data();
176
+ FAISS_THROW_IF_NOT(norm_tabs.size() == norm_dim12);
177
+
178
+ // combine them
179
+ for (idx_t i = 0; i < n; i++) {
180
+ memcpy(lut, ip_lut.data() + i * ip_dim12, ip_dim12 * sizeof(*lut));
181
+ lut += ip_dim12;
182
+ memcpy(lut, norm_lut, norm_dim12 * sizeof(*lut));
183
+ lut += norm_dim12;
184
+ }
185
+ }
186
+ }
187
+
188
+ void IndexAdditiveQuantizerFastScan::search(
189
+ idx_t n,
190
+ const float* x,
191
+ idx_t k,
192
+ float* distances,
193
+ idx_t* labels,
194
+ const SearchParameters* params) const {
195
+ FAISS_THROW_IF_NOT_MSG(
196
+ !params, "search params not supported for this index");
197
+ FAISS_THROW_IF_NOT(k > 0);
198
+ bool rescale = (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2);
199
+ if (!rescale) {
200
+ IndexFastScan::search(n, x, k, distances, labels);
201
+ return;
202
+ }
203
+
204
+ NormTableScaler scaler(norm_scale);
205
+ if (metric_type == METRIC_L2) {
206
+ search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
207
+ } else {
208
+ search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
209
+ }
210
+ }
211
+
212
+ void IndexAdditiveQuantizerFastScan::sa_decode(
213
+ idx_t n,
214
+ const uint8_t* bytes,
215
+ float* x) const {
216
+ aq->decode(bytes, x, n);
217
+ }
218
+
219
+ /**************************************************************************************
220
+ * IndexResidualQuantizerFastScan
221
+ **************************************************************************************/
222
+
223
+ IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan(
224
+ int d, ///< dimensionality of the input vectors
225
+ size_t M, ///< number of subquantizers
226
+ size_t nbits, ///< number of bit per subvector index
227
+ MetricType metric,
228
+ Search_type_t search_type,
229
+ int bbs)
230
+ : rq(d, M, nbits, search_type) {
231
+ init(&rq, metric, bbs);
232
+ }
233
+
234
+ IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan() {
235
+ aq = &rq;
236
+ }
237
+
238
+ /**************************************************************************************
239
+ * IndexLocalSearchQuantizerFastScan
240
+ **************************************************************************************/
241
+
242
+ IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan(
243
+ int d,
244
+ size_t M, ///< number of subquantizers
245
+ size_t nbits, ///< number of bit per subvector index
246
+ MetricType metric,
247
+ Search_type_t search_type,
248
+ int bbs)
249
+ : lsq(d, M, nbits, search_type) {
250
+ init(&lsq, metric, bbs);
251
+ }
252
+
253
+ IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan() {
254
+ aq = &lsq;
255
+ }
256
+
257
+ /**************************************************************************************
258
+ * IndexProductResidualQuantizerFastScan
259
+ **************************************************************************************/
260
+
261
+ IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan(
262
+ int d, ///< dimensionality of the input vectors
263
+ size_t nsplits, ///< number of residual quantizers
264
+ size_t Msub, ///< number of subquantizers per RQ
265
+ size_t nbits, ///< number of bit per subvector index
266
+ MetricType metric,
267
+ Search_type_t search_type,
268
+ int bbs)
269
+ : prq(d, nsplits, Msub, nbits, search_type) {
270
+ init(&prq, metric, bbs);
271
+ }
272
+
273
+ IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan() {
274
+ aq = &prq;
275
+ }
276
+
277
+ /**************************************************************************************
278
+ * IndexProductLocalSearchQuantizerFastScan
279
+ **************************************************************************************/
280
+
281
+ IndexProductLocalSearchQuantizerFastScan::
282
+ IndexProductLocalSearchQuantizerFastScan(
283
+ int d, ///< dimensionality of the input vectors
284
+ size_t nsplits, ///< number of local search quantizers
285
+ size_t Msub, ///< number of subquantizers per LSQ
286
+ size_t nbits, ///< number of bit per subvector index
287
+ MetricType metric,
288
+ Search_type_t search_type,
289
+ int bbs)
290
+ : plsq(d, nsplits, Msub, nbits, search_type) {
291
+ init(&plsq, metric, bbs);
292
+ }
293
+
294
+ IndexProductLocalSearchQuantizerFastScan::
295
+ IndexProductLocalSearchQuantizerFastScan() {
296
+ aq = &plsq;
297
+ }
298
+
299
+ } // namespace faiss