faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -40,6 +40,90 @@ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
40
40
 
41
41
  namespace {
42
42
 
43
+ /************************************************************
44
+ * DistanceComputer implementation
45
+ ************************************************************/
46
+
47
+ template <class VectorDistance>
48
+ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
49
+ std::vector<float> tmp;
50
+ const AdditiveQuantizer & aq;
51
+ VectorDistance vd;
52
+ size_t d;
53
+
54
+ AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
55
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
56
+ tmp(iaq.d * 2),
57
+ aq(*iaq.aq),
58
+ vd(vd),
59
+ d(iaq.d)
60
+ {}
61
+
62
+ const float *q;
63
+ void set_query(const float* x) final {
64
+ q = x;
65
+ }
66
+
67
+ float symmetric_dis(idx_t i, idx_t j) final {
68
+ aq.decode(codes + i * d, tmp.data(), 1);
69
+ aq.decode(codes + j * d, tmp.data() + d, 1);
70
+ return vd(tmp.data(), tmp.data() + d);
71
+ }
72
+
73
+ float distance_to_code(const uint8_t *code) final {
74
+ aq.decode(code, tmp.data(), 1);
75
+ return vd(q, tmp.data());
76
+ }
77
+
78
+ virtual ~AQDistanceComputerDecompress() {}
79
+ };
80
+
81
+
82
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st>
83
+ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
84
+ std::vector<float> LUT;
85
+ const AdditiveQuantizer & aq;
86
+ size_t d;
87
+
88
+ explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
89
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
90
+ LUT(iaq.aq->total_codebook_size + iaq.d * 2),
91
+ aq(*iaq.aq),
92
+ d(iaq.d)
93
+ {}
94
+
95
+ float bias;
96
+ void set_query(const float* x) final {
97
+ // this is quite sub-optimal for multiple queries
98
+ aq.compute_LUT(1, x, LUT.data());
99
+ if (is_IP) {
100
+ bias = 0;
101
+ } else {
102
+ bias = fvec_norm_L2sqr(x, d);
103
+ }
104
+ }
105
+
106
+ float symmetric_dis(idx_t i, idx_t j) final {
107
+ float *tmp = LUT.data();
108
+ aq.decode(codes + i * d, tmp, 1);
109
+ aq.decode(codes + j * d, tmp + d, 1);
110
+ return fvec_L2sqr(tmp, tmp + d, d);
111
+ }
112
+
113
+ float distance_to_code(const uint8_t *code) final {
114
+ return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
115
+ }
116
+
117
+ virtual ~AQDistanceComputerLUT() {}
118
+ };
119
+
120
+
121
+
122
+ /************************************************************
123
+ * scanning implementation for search
124
+ ************************************************************/
125
+
126
+
43
127
  template <class VectorDistance, class ResultHandler>
44
128
  void search_with_decompress(
45
129
  const IndexAdditiveQuantizer& ir,
@@ -111,12 +195,61 @@ void search_with_LUT(
111
195
 
112
196
  } // anonymous namespace
113
197
 
198
+
199
+ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
200
+
201
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
202
+ if (metric_type == METRIC_L2) {
203
+ using VD = VectorDistance<METRIC_L2>;
204
+ VD vd = {size_t(d), metric_arg};
205
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
206
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
207
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
208
+ VD vd = {size_t(d), metric_arg};
209
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
210
+ } else {
211
+ FAISS_THROW_MSG("unsupported metric");
212
+ }
213
+ } else {
214
+ if (metric_type == METRIC_INNER_PRODUCT) {
215
+ return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
216
+ } else {
217
+ switch(aq->search_type) {
218
+ #define DISPATCH(st) \
219
+ case AdditiveQuantizer::st: \
220
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
221
+ break;
222
+ DISPATCH(ST_norm_float)
223
+ DISPATCH(ST_LUT_nonorm)
224
+ DISPATCH(ST_norm_qint8)
225
+ DISPATCH(ST_norm_qint4)
226
+ DISPATCH(ST_norm_cqint4)
227
+ case AdditiveQuantizer::ST_norm_cqint8:
228
+ case AdditiveQuantizer::ST_norm_lsq2x4:
229
+ case AdditiveQuantizer::ST_norm_rq2x4:
230
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
231
+ break;
232
+ #undef DISPATCH
233
+ default:
234
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
235
+ }
236
+ }
237
+ }
238
+ }
239
+
240
+
241
+
242
+
114
243
  void IndexAdditiveQuantizer::search(
115
244
  idx_t n,
116
245
  const float* x,
117
246
  idx_t k,
118
247
  float* distances,
119
- idx_t* labels) const {
248
+ idx_t* labels,
249
+ const SearchParameters* params) const {
250
+
251
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
252
+
120
253
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
121
254
  if (metric_type == METRIC_L2) {
122
255
  using VD = VectorDistance<METRIC_L2>;
@@ -135,20 +268,23 @@ void IndexAdditiveQuantizer::search(
135
268
  search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
136
269
  } else {
137
270
  HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
138
-
139
- if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
140
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
141
- } else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
142
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
143
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
144
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
145
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
146
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
147
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8) {
271
+ switch(aq->search_type) {
272
+ #define DISPATCH(st) \
273
+ case AdditiveQuantizer::st: \
274
+ search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
275
+ break;
276
+ DISPATCH(ST_norm_float)
277
+ DISPATCH(ST_LUT_nonorm)
278
+ DISPATCH(ST_norm_qint8)
279
+ DISPATCH(ST_norm_qint4)
280
+ DISPATCH(ST_norm_cqint4)
281
+ case AdditiveQuantizer::ST_norm_cqint8:
282
+ case AdditiveQuantizer::ST_norm_lsq2x4:
283
+ case AdditiveQuantizer::ST_norm_rq2x4:
148
284
  search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
149
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
150
- search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
151
- } else {
285
+ break;
286
+ #undef DISPATCH
287
+ default:
152
288
  FAISS_THROW_FMT("search type %d not supported", aq->search_type);
153
289
  }
154
290
  }
@@ -220,6 +356,57 @@ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
220
356
  is_trained = true;
221
357
  }
222
358
 
359
+
360
+ /**************************************************************************************
361
+ * IndexProductResidualQuantizer
362
+ **************************************************************************************/
363
+
364
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer(
365
+ int d, ///< dimensionality of the input vectors
366
+ size_t nsplits, ///< number of residual quantizers
367
+ size_t Msub, ///< number of subquantizers per RQ
368
+ size_t nbits, ///< number of bit per subvector index
369
+ MetricType metric,
370
+ Search_type_t search_type)
371
+ : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
372
+ code_size = prq.code_size;
373
+ is_trained = false;
374
+ }
375
+
376
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer()
377
+ : IndexProductResidualQuantizer(0, 0, 0, 0) {}
378
+
379
+ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
380
+ prq.train(n, x);
381
+ is_trained = true;
382
+ }
383
+
384
+
385
+ /**************************************************************************************
386
+ * IndexProductLocalSearchQuantizer
387
+ **************************************************************************************/
388
+
389
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t nsplits, ///< number of local search quantizers
392
+ size_t Msub, ///< number of subquantizers per LSQ
393
+ size_t nbits, ///< number of bit per subvector index
394
+ MetricType metric,
395
+ Search_type_t search_type)
396
+ : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
397
+ code_size = plsq.code_size;
398
+ is_trained = false;
399
+ }
400
+
401
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
402
+ : IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
403
+
404
+ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
405
+ plsq.train(n, x);
406
+ is_trained = true;
407
+ }
408
+
409
+
223
410
  /**************************************************************************************
224
411
  * AdditiveCoarseQuantizer
225
412
  **************************************************************************************/
@@ -248,6 +435,13 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
248
435
  if (verbose) {
249
436
  printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
250
437
  }
438
+ size_t norms_size = sizeof(float) << aq->tot_bits;
439
+
440
+ FAISS_THROW_IF_NOT_MSG (
441
+ norms_size <= aq->max_mem_distances,
442
+ "the RCQ norms matrix will become too large, please reduce the number of quantization steps"
443
+ );
444
+
251
445
  aq->train(n, x);
252
446
  is_trained = true;
253
447
  ntotal = (idx_t)1 << aq->tot_bits;
@@ -268,7 +462,11 @@ void AdditiveCoarseQuantizer::search(
268
462
  const float* x,
269
463
  idx_t k,
270
464
  float* distances,
271
- idx_t* labels) const {
465
+ idx_t* labels,
466
+ const SearchParameters * params) const {
467
+
468
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
469
+
272
470
  if (metric_type == METRIC_INNER_PRODUCT) {
273
471
  aq->knn_centroids_inner_product(n, x, k, distances, labels);
274
472
  } else if (metric_type == METRIC_L2) {
@@ -321,7 +519,12 @@ void ResidualCoarseQuantizer::search(
321
519
  const float* x,
322
520
  idx_t k,
323
521
  float* distances,
324
- idx_t* labels) const {
522
+ idx_t* labels,
523
+ const SearchParameters * params
524
+ ) const {
525
+
526
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
527
+
325
528
  if (beam_factor < 0) {
326
529
  AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
327
530
  return;
@@ -15,6 +15,7 @@
15
15
 
16
16
  #include <faiss/IndexFlatCodes.h>
17
17
  #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ProductAdditiveQuantizer.h>
18
19
  #include <faiss/impl/ResidualQuantizer.h>
19
20
  #include <faiss/impl/platform_macros.h>
20
21
 
@@ -28,8 +29,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
28
29
  using Search_type_t = AdditiveQuantizer::Search_type_t;
29
30
 
30
31
  explicit IndexAdditiveQuantizer(
31
- idx_t d = 0,
32
- AdditiveQuantizer* aq = nullptr,
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
33
34
  MetricType metric = METRIC_L2);
34
35
 
35
36
  void search(
@@ -37,12 +38,15 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
37
38
  const float* x,
38
39
  idx_t k,
39
40
  float* distances,
40
- idx_t* labels) const override;
41
+ idx_t* labels,
42
+ const SearchParameters* params = nullptr) const override;
41
43
 
42
44
  /* The standalone codec interface */
43
45
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
44
46
 
45
47
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
48
+
49
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
46
50
  };
47
51
 
48
52
  /** Index based on a residual quantizer. Stored vectors are
@@ -98,6 +102,58 @@ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
98
102
  void train(idx_t n, const float* x) override;
99
103
  };
100
104
 
105
+ /** Index based on a product residual quantizer.
106
+ */
107
+ struct IndexProductResidualQuantizer : IndexAdditiveQuantizer {
108
+ /// The product residual quantizer used to encode the vectors
109
+ ProductResidualQuantizer prq;
110
+
111
+ /** Constructor.
112
+ *
113
+ * @param d dimensionality of the input vectors
114
+ * @param nsplits number of residual quantizers
115
+ * @param Msub number of subquantizers per RQ
116
+ * @param nbits number of bit per subvector index
117
+ */
118
+ IndexProductResidualQuantizer(
119
+ int d, ///< dimensionality of the input vectors
120
+ size_t nsplits, ///< number of residual quantizers
121
+ size_t Msub, ///< number of subquantizers per RQ
122
+ size_t nbits, ///< number of bit per subvector index
123
+ MetricType metric = METRIC_L2,
124
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
125
+
126
+ IndexProductResidualQuantizer();
127
+
128
+ void train(idx_t n, const float* x) override;
129
+ };
130
+
131
+ /** Index based on a product local search quantizer.
132
+ */
133
+ struct IndexProductLocalSearchQuantizer : IndexAdditiveQuantizer {
134
+ /// The product local search quantizer used to encode the vectors
135
+ ProductLocalSearchQuantizer plsq;
136
+
137
+ /** Constructor.
138
+ *
139
+ * @param d dimensionality of the input vectors
140
+ * @param nsplits number of local search quantizers
141
+ * @param Msub number of subquantizers per LSQ
142
+ * @param nbits number of bit per subvector index
143
+ */
144
+ IndexProductLocalSearchQuantizer(
145
+ int d, ///< dimensionality of the input vectors
146
+ size_t nsplits, ///< number of local search quantizers
147
+ size_t Msub, ///< number of subquantizers per LSQ
148
+ size_t nbits, ///< number of bit per subvector index
149
+ MetricType metric = METRIC_L2,
150
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
151
+
152
+ IndexProductLocalSearchQuantizer();
153
+
154
+ void train(idx_t n, const float* x) override;
155
+ };
156
+
101
157
  /** A "virtual" index where the elements are the residual quantizer centroids.
102
158
  *
103
159
  * Intended for use as a coarse quantizer in an IndexIVF.
@@ -121,7 +177,8 @@ struct AdditiveCoarseQuantizer : Index {
121
177
  const float* x,
122
178
  idx_t k,
123
179
  float* distances,
124
- idx_t* labels) const override;
180
+ idx_t* labels,
181
+ const SearchParameters* params = nullptr) const override;
125
182
 
126
183
  void reconstruct(idx_t key, float* recons) const override;
127
184
  void train(idx_t n, const float* x) override;
@@ -166,7 +223,8 @@ struct ResidualCoarseQuantizer : AdditiveCoarseQuantizer {
166
223
  const float* x,
167
224
  idx_t k,
168
225
  float* distances,
169
- idx_t* labels) const override;
226
+ idx_t* labels,
227
+ const SearchParameters* params = nullptr) const override;
170
228
 
171
229
  ResidualCoarseQuantizer();
172
230
  };
@@ -0,0 +1,299 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/IndexAdditiveQuantizerFastScan.h>
9
+
10
+ #include <limits.h>
11
+ #include <cassert>
12
+ #include <memory>
13
+
14
+ #include <omp.h>
15
+
16
+ #include <faiss/impl/FaissAssert.h>
17
+ #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/LookupTableScaler.h>
19
+ #include <faiss/impl/ResidualQuantizer.h>
20
+ #include <faiss/impl/pq4_fast_scan.h>
21
+ #include <faiss/utils/quantize_lut.h>
22
+ #include <faiss/utils/utils.h>
23
+
24
+ namespace faiss {
25
+
26
/// Smallest multiple of b that is >= a (b must be non-zero).
inline size_t roundup(size_t a, size_t b) {
    const size_t nblocks = (a + b - 1) / b;
    return nblocks * b;
}
29
+
30
+ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
31
+ AdditiveQuantizer* aq,
32
+ MetricType metric,
33
+ int bbs) {
34
+ init(aq, metric, bbs);
35
+ }
36
+
37
+ void IndexAdditiveQuantizerFastScan::init(
38
+ AdditiveQuantizer* aq,
39
+ MetricType metric,
40
+ int bbs) {
41
+ FAISS_THROW_IF_NOT(aq != nullptr);
42
+ FAISS_THROW_IF_NOT(!aq->nbits.empty());
43
+ FAISS_THROW_IF_NOT(aq->nbits[0] == 4);
44
+ if (metric == METRIC_INNER_PRODUCT) {
45
+ FAISS_THROW_IF_NOT_MSG(
46
+ aq->search_type == AdditiveQuantizer::ST_LUT_nonorm,
47
+ "Search type must be ST_LUT_nonorm for IP metric");
48
+ } else {
49
+ FAISS_THROW_IF_NOT_MSG(
50
+ aq->search_type == AdditiveQuantizer::ST_norm_lsq2x4 ||
51
+ aq->search_type == AdditiveQuantizer::ST_norm_rq2x4,
52
+ "Search type must be lsq2x4 or rq2x4 for L2 metric");
53
+ }
54
+
55
+ this->aq = aq;
56
+ if (metric == METRIC_L2) {
57
+ M = aq->M + 2; // 2x4 bits AQ
58
+ } else {
59
+ M = aq->M;
60
+ }
61
+ init_fastscan(aq->d, M, 4, metric, bbs);
62
+
63
+ max_train_points = 1024 * ksub * M;
64
+ }
65
+
66
+ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan()
67
+ : IndexFastScan() {
68
+ is_trained = false;
69
+ aq = nullptr;
70
+ }
71
+
72
+ IndexAdditiveQuantizerFastScan::IndexAdditiveQuantizerFastScan(
73
+ const IndexAdditiveQuantizer& orig,
74
+ int bbs) {
75
+ init(orig.aq, orig.metric_type, bbs);
76
+
77
+ ntotal = orig.ntotal;
78
+ is_trained = orig.is_trained;
79
+ orig_codes = orig.codes.data();
80
+
81
+ ntotal2 = roundup(ntotal, bbs);
82
+ codes.resize(ntotal2 * M2 / 2);
83
+ pq4_pack_codes(orig_codes, ntotal, M, ntotal2, bbs, M2, codes.get());
84
+ }
85
+
86
+ IndexAdditiveQuantizerFastScan::~IndexAdditiveQuantizerFastScan() {}
87
+
88
+ void IndexAdditiveQuantizerFastScan::train(idx_t n, const float* x_in) {
89
+ if (is_trained) {
90
+ return;
91
+ }
92
+
93
+ const int seed = 0x12345;
94
+ size_t nt = n;
95
+ const float* x = fvecs_maybe_subsample(
96
+ d, &nt, max_train_points, x_in, verbose, seed);
97
+ n = nt;
98
+ if (verbose) {
99
+ printf("training additive quantizer on %zd vectors\n", nt);
100
+ }
101
+
102
+ aq->verbose = verbose;
103
+ aq->train(n, x);
104
+ if (metric_type == METRIC_L2) {
105
+ estimate_norm_scale(n, x);
106
+ }
107
+
108
+ is_trained = true;
109
+ }
110
+
111
+ void IndexAdditiveQuantizerFastScan::estimate_norm_scale(
112
+ idx_t n,
113
+ const float* x_in) {
114
+ FAISS_THROW_IF_NOT(metric_type == METRIC_L2);
115
+
116
+ constexpr int seed = 0x980903;
117
+ constexpr size_t max_points_estimated = 65536;
118
+ size_t ns = n;
119
+ const float* x = fvecs_maybe_subsample(
120
+ d, &ns, max_points_estimated, x_in, verbose, seed);
121
+ n = ns;
122
+ std::unique_ptr<float[]> del_x;
123
+ if (x != x_in) {
124
+ del_x.reset((float*)x);
125
+ }
126
+
127
+ std::vector<float> dis_tables(n * M * ksub);
128
+ compute_float_LUT(dis_tables.data(), n, x);
129
+
130
+ // here we compute the mean of scales for each query
131
+ // TODO: try max of scales
132
+ double scale = 0;
133
+
134
+ #pragma omp parallel for reduction(+ : scale)
135
+ for (idx_t i = 0; i < n; i++) {
136
+ const float* lut = dis_tables.data() + i * M * ksub;
137
+ scale += quantize_lut::aq_estimate_norm_scale(M, ksub, 2, lut);
138
+ }
139
+ scale /= n;
140
+ norm_scale = (int)std::roundf(std::max(scale, 1.0));
141
+
142
+ if (verbose) {
143
+ printf("estimated norm scale: %lf\n", scale);
144
+ printf("rounded norm scale: %d\n", norm_scale);
145
+ }
146
+ }
147
+
148
+ void IndexAdditiveQuantizerFastScan::compute_codes(
149
+ uint8_t* tmp_codes,
150
+ idx_t n,
151
+ const float* x) const {
152
+ aq->compute_codes(x, tmp_codes, n);
153
+ }
154
+
155
+ void IndexAdditiveQuantizerFastScan::compute_float_LUT(
156
+ float* lut,
157
+ idx_t n,
158
+ const float* x) const {
159
+ if (metric_type == METRIC_INNER_PRODUCT) {
160
+ aq->compute_LUT(n, x, lut, 1.0f);
161
+ } else {
162
+ // compute inner product look-up tables
163
+ const size_t ip_dim12 = aq->M * ksub;
164
+ const size_t norm_dim12 = 2 * ksub;
165
+ std::vector<float> ip_lut(n * ip_dim12);
166
+ aq->compute_LUT(n, x, ip_lut.data(), -2.0f);
167
+
168
+ // copy and rescale norm look-up tables
169
+ auto norm_tabs = aq->norm_tabs;
170
+ if (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2) {
171
+ for (size_t i = 0; i < norm_tabs.size(); i++) {
172
+ norm_tabs[i] /= norm_scale;
173
+ }
174
+ }
175
+ const float* norm_lut = norm_tabs.data();
176
+ FAISS_THROW_IF_NOT(norm_tabs.size() == norm_dim12);
177
+
178
+ // combine them
179
+ for (idx_t i = 0; i < n; i++) {
180
+ memcpy(lut, ip_lut.data() + i * ip_dim12, ip_dim12 * sizeof(*lut));
181
+ lut += ip_dim12;
182
+ memcpy(lut, norm_lut, norm_dim12 * sizeof(*lut));
183
+ lut += norm_dim12;
184
+ }
185
+ }
186
+ }
187
+
188
+ void IndexAdditiveQuantizerFastScan::search(
189
+ idx_t n,
190
+ const float* x,
191
+ idx_t k,
192
+ float* distances,
193
+ idx_t* labels,
194
+ const SearchParameters* params) const {
195
+ FAISS_THROW_IF_NOT_MSG(
196
+ !params, "search params not supported for this index");
197
+ FAISS_THROW_IF_NOT(k > 0);
198
+ bool rescale = (rescale_norm && norm_scale > 1 && metric_type == METRIC_L2);
199
+ if (!rescale) {
200
+ IndexFastScan::search(n, x, k, distances, labels);
201
+ return;
202
+ }
203
+
204
+ NormTableScaler scaler(norm_scale);
205
+ if (metric_type == METRIC_L2) {
206
+ search_dispatch_implem<true>(n, x, k, distances, labels, scaler);
207
+ } else {
208
+ search_dispatch_implem<false>(n, x, k, distances, labels, scaler);
209
+ }
210
+ }
211
+
212
+ void IndexAdditiveQuantizerFastScan::sa_decode(
213
+ idx_t n,
214
+ const uint8_t* bytes,
215
+ float* x) const {
216
+ aq->decode(bytes, x, n);
217
+ }
218
+
219
+ /**************************************************************************************
220
+ * IndexResidualQuantizerFastScan
221
+ **************************************************************************************/
222
+
223
+ IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan(
224
+ int d, ///< dimensionality of the input vectors
225
+ size_t M, ///< number of subquantizers
226
+ size_t nbits, ///< number of bit per subvector index
227
+ MetricType metric,
228
+ Search_type_t search_type,
229
+ int bbs)
230
+ : rq(d, M, nbits, search_type) {
231
+ init(&rq, metric, bbs);
232
+ }
233
+
234
+ IndexResidualQuantizerFastScan::IndexResidualQuantizerFastScan() {
235
+ aq = &rq;
236
+ }
237
+
238
+ /**************************************************************************************
239
+ * IndexLocalSearchQuantizerFastScan
240
+ **************************************************************************************/
241
+
242
+ IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan(
243
+ int d,
244
+ size_t M, ///< number of subquantizers
245
+ size_t nbits, ///< number of bit per subvector index
246
+ MetricType metric,
247
+ Search_type_t search_type,
248
+ int bbs)
249
+ : lsq(d, M, nbits, search_type) {
250
+ init(&lsq, metric, bbs);
251
+ }
252
+
253
+ IndexLocalSearchQuantizerFastScan::IndexLocalSearchQuantizerFastScan() {
254
+ aq = &lsq;
255
+ }
256
+
257
+ /**************************************************************************************
258
+ * IndexProductResidualQuantizerFastScan
259
+ **************************************************************************************/
260
+
261
+ IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan(
262
+ int d, ///< dimensionality of the input vectors
263
+ size_t nsplits, ///< number of residual quantizers
264
+ size_t Msub, ///< number of subquantizers per RQ
265
+ size_t nbits, ///< number of bit per subvector index
266
+ MetricType metric,
267
+ Search_type_t search_type,
268
+ int bbs)
269
+ : prq(d, nsplits, Msub, nbits, search_type) {
270
+ init(&prq, metric, bbs);
271
+ }
272
+
273
+ IndexProductResidualQuantizerFastScan::IndexProductResidualQuantizerFastScan() {
274
+ aq = &prq;
275
+ }
276
+
277
+ /**************************************************************************************
278
+ * IndexProductLocalSearchQuantizerFastScan
279
+ **************************************************************************************/
280
+
281
+ IndexProductLocalSearchQuantizerFastScan::
282
+ IndexProductLocalSearchQuantizerFastScan(
283
+ int d, ///< dimensionality of the input vectors
284
+ size_t nsplits, ///< number of local search quantizers
285
+ size_t Msub, ///< number of subquantizers per LSQ
286
+ size_t nbits, ///< number of bit per subvector index
287
+ MetricType metric,
288
+ Search_type_t search_type,
289
+ int bbs)
290
+ : plsq(d, nsplits, Msub, nbits, search_type) {
291
+ init(&plsq, metric, bbs);
292
+ }
293
+
294
+ IndexProductLocalSearchQuantizerFastScan::
295
+ IndexProductLocalSearchQuantizerFastScan() {
296
+ aq = &plsq;
297
+ }
298
+
299
+ } // namespace faiss