faiss 0.2.4 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +23 -21
  5. data/ext/faiss/extconf.rb +11 -0
  6. data/ext/faiss/index.cpp +17 -4
  7. data/ext/faiss/index_binary.cpp +6 -6
  8. data/ext/faiss/product_quantizer.cpp +4 -4
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  11. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  12. data/vendor/faiss/faiss/IVFlib.h +26 -2
  13. data/vendor/faiss/faiss/Index.cpp +36 -3
  14. data/vendor/faiss/faiss/Index.h +43 -6
  15. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  16. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  21. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  23. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  24. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  26. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  28. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  29. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  30. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  31. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  32. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  33. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  34. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  35. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  36. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  37. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  38. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  39. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  40. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  41. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  42. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  43. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  44. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  48. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  50. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  51. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  52. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  53. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  54. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  56. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  57. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  58. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  60. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  61. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  62. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  63. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  64. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  65. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  66. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  67. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  68. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  69. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  70. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  71. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  72. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  73. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  74. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  75. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  76. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  78. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  80. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  82. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  83. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  84. data/vendor/faiss/faiss/IndexShards.h +2 -1
  85. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  86. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  87. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  88. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  89. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  90. data/vendor/faiss/faiss/clone_index.h +3 -0
  91. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  93. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  101. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  102. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  103. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  105. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  106. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  110. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  111. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  112. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  113. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  114. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  116. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  118. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  119. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  120. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  122. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  124. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  125. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  127. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  128. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  129. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  131. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  132. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  133. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  134. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  136. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  138. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  139. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  141. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  142. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  144. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  145. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  146. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  147. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  151. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  152. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  153. data/vendor/faiss/faiss/index_io.h +5 -0
  154. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  155. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  156. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  157. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  158. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  159. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  160. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  161. data/vendor/faiss/faiss/utils/distances.h +113 -15
  162. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  163. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  164. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  165. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  166. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  167. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  168. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  169. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  170. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  172. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  173. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  174. data/vendor/faiss/faiss/utils/random.h +5 -0
  175. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  176. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  177. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  178. metadata +37 -3
@@ -111,7 +111,8 @@ void Index2Layer::search(
111
111
  const float* /*x*/,
112
112
  idx_t /*k*/,
113
113
  float* /*distances*/,
114
- idx_t* /*labels*/) const {
114
+ idx_t* /*labels*/,
115
+ const SearchParameters* /* params */) const {
115
116
  FAISS_THROW_MSG("not implemented");
116
117
  }
117
118
 
@@ -282,10 +283,13 @@ DistanceComputer* Index2Layer::get_distance_computer() const {
282
283
 
283
284
  /* The standalone codec interface */
284
285
 
286
+ // block size used in Index2Layer::sa_encode
287
+ int index2layer_sa_encode_bs = 32768;
288
+
285
289
  void Index2Layer::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
286
290
  FAISS_THROW_IF_NOT(is_trained);
287
291
 
288
- idx_t bs = 32768;
292
+ idx_t bs = index2layer_sa_encode_bs;
289
293
  if (n > bs) {
290
294
  for (idx_t i0 = 0; i0 < n; i0 += bs) {
291
295
  idx_t i1 = std::min(i0 + bs, n);
@@ -14,6 +14,7 @@
14
14
  #include <faiss/IndexFlatCodes.h>
15
15
  #include <faiss/IndexIVF.h>
16
16
  #include <faiss/IndexPQ.h>
17
+ #include <faiss/impl/platform_macros.h>
17
18
 
18
19
  namespace faiss {
19
20
 
@@ -56,7 +57,8 @@ struct Index2Layer : IndexFlatCodes {
56
57
  const float* x,
57
58
  idx_t k,
58
59
  float* distances,
59
- idx_t* labels) const override;
60
+ idx_t* labels,
61
+ const SearchParameters* params = nullptr) const override;
60
62
 
61
63
  DistanceComputer* get_distance_computer() const override;
62
64
 
@@ -68,4 +70,7 @@ struct Index2Layer : IndexFlatCodes {
68
70
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
69
71
  };
70
72
 
73
+ // block size used in Index2Layer::sa_encode
74
+ FAISS_API extern int index2layer_sa_encode_bs;
75
+
71
76
  } // namespace faiss
@@ -40,6 +40,90 @@ IndexAdditiveQuantizer::IndexAdditiveQuantizer(
40
40
 
41
41
  namespace {
42
42
 
43
+ /************************************************************
44
+ * DistanceComputer implementation
45
+ ************************************************************/
46
+
47
+ template <class VectorDistance>
48
+ struct AQDistanceComputerDecompress: FlatCodesDistanceComputer {
49
+ std::vector<float> tmp;
50
+ const AdditiveQuantizer & aq;
51
+ VectorDistance vd;
52
+ size_t d;
53
+
54
+ AQDistanceComputerDecompress(const IndexAdditiveQuantizer &iaq, VectorDistance vd):
55
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
56
+ tmp(iaq.d * 2),
57
+ aq(*iaq.aq),
58
+ vd(vd),
59
+ d(iaq.d)
60
+ {}
61
+
62
+ const float *q;
63
+ void set_query(const float* x) final {
64
+ q = x;
65
+ }
66
+
67
+ float symmetric_dis(idx_t i, idx_t j) final {
68
+ aq.decode(codes + i * d, tmp.data(), 1);
69
+ aq.decode(codes + j * d, tmp.data() + d, 1);
70
+ return vd(tmp.data(), tmp.data() + d);
71
+ }
72
+
73
+ float distance_to_code(const uint8_t *code) final {
74
+ aq.decode(code, tmp.data(), 1);
75
+ return vd(q, tmp.data());
76
+ }
77
+
78
+ virtual ~AQDistanceComputerDecompress() {}
79
+ };
80
+
81
+
82
+ template<bool is_IP, AdditiveQuantizer::Search_type_t st>
83
+ struct AQDistanceComputerLUT: FlatCodesDistanceComputer {
84
+ std::vector<float> LUT;
85
+ const AdditiveQuantizer & aq;
86
+ size_t d;
87
+
88
+ explicit AQDistanceComputerLUT(const IndexAdditiveQuantizer &iaq):
89
+ FlatCodesDistanceComputer(iaq.codes.data(), iaq.code_size),
90
+ LUT(iaq.aq->total_codebook_size + iaq.d * 2),
91
+ aq(*iaq.aq),
92
+ d(iaq.d)
93
+ {}
94
+
95
+ float bias;
96
+ void set_query(const float* x) final {
97
+ // this is quite sub-optimal for multiple queries
98
+ aq.compute_LUT(1, x, LUT.data());
99
+ if (is_IP) {
100
+ bias = 0;
101
+ } else {
102
+ bias = fvec_norm_L2sqr(x, d);
103
+ }
104
+ }
105
+
106
+ float symmetric_dis(idx_t i, idx_t j) final {
107
+ float *tmp = LUT.data();
108
+ aq.decode(codes + i * d, tmp, 1);
109
+ aq.decode(codes + j * d, tmp + d, 1);
110
+ return fvec_L2sqr(tmp, tmp + d, d);
111
+ }
112
+
113
+ float distance_to_code(const uint8_t *code) final {
114
+ return bias + aq.compute_1_distance_LUT<is_IP, st>(code, LUT.data());
115
+ }
116
+
117
+ virtual ~AQDistanceComputerLUT() {}
118
+ };
119
+
120
+
121
+
122
+ /************************************************************
123
+ * scanning implementation for search
124
+ ************************************************************/
125
+
126
+
43
127
  template <class VectorDistance, class ResultHandler>
44
128
  void search_with_decompress(
45
129
  const IndexAdditiveQuantizer& ir,
@@ -111,12 +195,61 @@ void search_with_LUT(
111
195
 
112
196
  } // anonymous namespace
113
197
 
198
+
199
+ FlatCodesDistanceComputer * IndexAdditiveQuantizer::get_FlatCodesDistanceComputer() const {
200
+
201
+ if (aq->search_type == AdditiveQuantizer::ST_decompress) {
202
+ if (metric_type == METRIC_L2) {
203
+ using VD = VectorDistance<METRIC_L2>;
204
+ VD vd = {size_t(d), metric_arg};
205
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
206
+ } else if (metric_type == METRIC_INNER_PRODUCT) {
207
+ using VD = VectorDistance<METRIC_INNER_PRODUCT>;
208
+ VD vd = {size_t(d), metric_arg};
209
+ return new AQDistanceComputerDecompress<VD>(*this, vd);
210
+ } else {
211
+ FAISS_THROW_MSG("unsupported metric");
212
+ }
213
+ } else {
214
+ if (metric_type == METRIC_INNER_PRODUCT) {
215
+ return new AQDistanceComputerLUT<true, AdditiveQuantizer::ST_LUT_nonorm>(*this);
216
+ } else {
217
+ switch(aq->search_type) {
218
+ #define DISPATCH(st) \
219
+ case AdditiveQuantizer::st: \
220
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::st> (*this);\
221
+ break;
222
+ DISPATCH(ST_norm_float)
223
+ DISPATCH(ST_LUT_nonorm)
224
+ DISPATCH(ST_norm_qint8)
225
+ DISPATCH(ST_norm_qint4)
226
+ DISPATCH(ST_norm_cqint4)
227
+ case AdditiveQuantizer::ST_norm_cqint8:
228
+ case AdditiveQuantizer::ST_norm_lsq2x4:
229
+ case AdditiveQuantizer::ST_norm_rq2x4:
230
+ return new AQDistanceComputerLUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this);\
231
+ break;
232
+ #undef DISPATCH
233
+ default:
234
+ FAISS_THROW_FMT("search type %d not supported", aq->search_type);
235
+ }
236
+ }
237
+ }
238
+ }
239
+
240
+
241
+
242
+
114
243
  void IndexAdditiveQuantizer::search(
115
244
  idx_t n,
116
245
  const float* x,
117
246
  idx_t k,
118
247
  float* distances,
119
- idx_t* labels) const {
248
+ idx_t* labels,
249
+ const SearchParameters* params) const {
250
+
251
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
252
+
120
253
  if (aq->search_type == AdditiveQuantizer::ST_decompress) {
121
254
  if (metric_type == METRIC_L2) {
122
255
  using VD = VectorDistance<METRIC_L2>;
@@ -135,20 +268,23 @@ void IndexAdditiveQuantizer::search(
135
268
  search_with_LUT<true, AdditiveQuantizer::ST_LUT_nonorm> (*this, x, rh);
136
269
  } else {
137
270
  HeapResultHandler<CMax<float, idx_t> > rh(n, distances, labels, k);
138
-
139
- if (aq->search_type == AdditiveQuantizer::ST_norm_float) {
140
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
141
- } else if (aq->search_type == AdditiveQuantizer::ST_LUT_nonorm) {
142
- search_with_LUT<false, AdditiveQuantizer::ST_norm_float> (*this, x, rh);
143
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint8) {
144
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint8> (*this, x, rh);
145
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_qint4) {
146
- search_with_LUT<false, AdditiveQuantizer::ST_norm_qint4> (*this, x, rh);
147
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint8) {
271
+ switch(aq->search_type) {
272
+ #define DISPATCH(st) \
273
+ case AdditiveQuantizer::st: \
274
+ search_with_LUT<false, AdditiveQuantizer::st> (*this, x, rh);\
275
+ break;
276
+ DISPATCH(ST_norm_float)
277
+ DISPATCH(ST_LUT_nonorm)
278
+ DISPATCH(ST_norm_qint8)
279
+ DISPATCH(ST_norm_qint4)
280
+ DISPATCH(ST_norm_cqint4)
281
+ case AdditiveQuantizer::ST_norm_cqint8:
282
+ case AdditiveQuantizer::ST_norm_lsq2x4:
283
+ case AdditiveQuantizer::ST_norm_rq2x4:
148
284
  search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint8> (*this, x, rh);
149
- } else if (aq->search_type == AdditiveQuantizer::ST_norm_cqint4) {
150
- search_with_LUT<false, AdditiveQuantizer::ST_norm_cqint4> (*this, x, rh);
151
- } else {
285
+ break;
286
+ #undef DISPATCH
287
+ default:
152
288
  FAISS_THROW_FMT("search type %d not supported", aq->search_type);
153
289
  }
154
290
  }
@@ -220,6 +356,57 @@ void IndexLocalSearchQuantizer::train(idx_t n, const float* x) {
220
356
  is_trained = true;
221
357
  }
222
358
 
359
+
360
+ /**************************************************************************************
361
+ * IndexProductResidualQuantizer
362
+ **************************************************************************************/
363
+
364
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer(
365
+ int d, ///< dimensionality of the input vectors
366
+ size_t nsplits, ///< number of residual quantizers
367
+ size_t Msub, ///< number of subquantizers per RQ
368
+ size_t nbits, ///< number of bit per subvector index
369
+ MetricType metric,
370
+ Search_type_t search_type)
371
+ : IndexAdditiveQuantizer(d, &prq, metric), prq(d, nsplits, Msub, nbits, search_type) {
372
+ code_size = prq.code_size;
373
+ is_trained = false;
374
+ }
375
+
376
+ IndexProductResidualQuantizer::IndexProductResidualQuantizer()
377
+ : IndexProductResidualQuantizer(0, 0, 0, 0) {}
378
+
379
+ void IndexProductResidualQuantizer::train(idx_t n, const float* x) {
380
+ prq.train(n, x);
381
+ is_trained = true;
382
+ }
383
+
384
+
385
+ /**************************************************************************************
386
+ * IndexProductLocalSearchQuantizer
387
+ **************************************************************************************/
388
+
389
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer(
390
+ int d, ///< dimensionality of the input vectors
391
+ size_t nsplits, ///< number of local search quantizers
392
+ size_t Msub, ///< number of subquantizers per LSQ
393
+ size_t nbits, ///< number of bit per subvector index
394
+ MetricType metric,
395
+ Search_type_t search_type)
396
+ : IndexAdditiveQuantizer(d, &plsq, metric), plsq(d, nsplits, Msub, nbits, search_type) {
397
+ code_size = plsq.code_size;
398
+ is_trained = false;
399
+ }
400
+
401
+ IndexProductLocalSearchQuantizer::IndexProductLocalSearchQuantizer()
402
+ : IndexProductLocalSearchQuantizer(0, 0, 0, 0) {}
403
+
404
+ void IndexProductLocalSearchQuantizer::train(idx_t n, const float* x) {
405
+ plsq.train(n, x);
406
+ is_trained = true;
407
+ }
408
+
409
+
223
410
  /**************************************************************************************
224
411
  * AdditiveCoarseQuantizer
225
412
  **************************************************************************************/
@@ -248,6 +435,13 @@ void AdditiveCoarseQuantizer::train(idx_t n, const float* x) {
248
435
  if (verbose) {
249
436
  printf("AdditiveCoarseQuantizer::train: training on %zd vectors\n", size_t(n));
250
437
  }
438
+ size_t norms_size = sizeof(float) << aq->tot_bits;
439
+
440
+ FAISS_THROW_IF_NOT_MSG (
441
+ norms_size <= aq->max_mem_distances,
442
+ "the RCQ norms matrix will become too large, please reduce the number of quantization steps"
443
+ );
444
+
251
445
  aq->train(n, x);
252
446
  is_trained = true;
253
447
  ntotal = (idx_t)1 << aq->tot_bits;
@@ -268,7 +462,11 @@ void AdditiveCoarseQuantizer::search(
268
462
  const float* x,
269
463
  idx_t k,
270
464
  float* distances,
271
- idx_t* labels) const {
465
+ idx_t* labels,
466
+ const SearchParameters * params) const {
467
+
468
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
469
+
272
470
  if (metric_type == METRIC_INNER_PRODUCT) {
273
471
  aq->knn_centroids_inner_product(n, x, k, distances, labels);
274
472
  } else if (metric_type == METRIC_L2) {
@@ -321,7 +519,12 @@ void ResidualCoarseQuantizer::search(
321
519
  const float* x,
322
520
  idx_t k,
323
521
  float* distances,
324
- idx_t* labels) const {
522
+ idx_t* labels,
523
+ const SearchParameters * params
524
+ ) const {
525
+
526
+ FAISS_THROW_IF_NOT_MSG(!params, "search params not supported for this index");
527
+
325
528
  if (beam_factor < 0) {
326
529
  AdditiveCoarseQuantizer::search(n, x, k, distances, labels);
327
530
  return;
@@ -15,6 +15,7 @@
15
15
 
16
16
  #include <faiss/IndexFlatCodes.h>
17
17
  #include <faiss/impl/LocalSearchQuantizer.h>
18
+ #include <faiss/impl/ProductAdditiveQuantizer.h>
18
19
  #include <faiss/impl/ResidualQuantizer.h>
19
20
  #include <faiss/impl/platform_macros.h>
20
21
 
@@ -28,8 +29,8 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
28
29
  using Search_type_t = AdditiveQuantizer::Search_type_t;
29
30
 
30
31
  explicit IndexAdditiveQuantizer(
31
- idx_t d = 0,
32
- AdditiveQuantizer* aq = nullptr,
32
+ idx_t d,
33
+ AdditiveQuantizer* aq,
33
34
  MetricType metric = METRIC_L2);
34
35
 
35
36
  void search(
@@ -37,12 +38,15 @@ struct IndexAdditiveQuantizer : IndexFlatCodes {
37
38
  const float* x,
38
39
  idx_t k,
39
40
  float* distances,
40
- idx_t* labels) const override;
41
+ idx_t* labels,
42
+ const SearchParameters* params = nullptr) const override;
41
43
 
42
44
  /* The standalone codec interface */
43
45
  void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
44
46
 
45
47
  void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
48
+
49
+ FlatCodesDistanceComputer* get_FlatCodesDistanceComputer() const override;
46
50
  };
47
51
 
48
52
  /** Index based on a residual quantizer. Stored vectors are
@@ -98,6 +102,58 @@ struct IndexLocalSearchQuantizer : IndexAdditiveQuantizer {
98
102
  void train(idx_t n, const float* x) override;
99
103
  };
100
104
 
105
+ /** Index based on a product residual quantizer.
106
+ */
107
+ struct IndexProductResidualQuantizer : IndexAdditiveQuantizer {
108
+ /// The product residual quantizer used to encode the vectors
109
+ ProductResidualQuantizer prq;
110
+
111
+ /** Constructor.
112
+ *
113
+ * @param d dimensionality of the input vectors
114
+ * @param nsplits number of residual quantizers
115
+ * @param Msub number of subquantizers per RQ
116
+ * @param nbits number of bit per subvector index
117
+ */
118
+ IndexProductResidualQuantizer(
119
+ int d, ///< dimensionality of the input vectors
120
+ size_t nsplits, ///< number of residual quantizers
121
+ size_t Msub, ///< number of subquantizers per RQ
122
+ size_t nbits, ///< number of bit per subvector index
123
+ MetricType metric = METRIC_L2,
124
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
125
+
126
+ IndexProductResidualQuantizer();
127
+
128
+ void train(idx_t n, const float* x) override;
129
+ };
130
+
131
+ /** Index based on a product local search quantizer.
132
+ */
133
+ struct IndexProductLocalSearchQuantizer : IndexAdditiveQuantizer {
134
+ /// The product local search quantizer used to encode the vectors
135
+ ProductLocalSearchQuantizer plsq;
136
+
137
+ /** Constructor.
138
+ *
139
+ * @param d dimensionality of the input vectors
140
+ * @param nsplits number of local search quantizers
141
+ * @param Msub number of subquantizers per LSQ
142
+ * @param nbits number of bit per subvector index
143
+ */
144
+ IndexProductLocalSearchQuantizer(
145
+ int d, ///< dimensionality of the input vectors
146
+ size_t nsplits, ///< number of local search quantizers
147
+ size_t Msub, ///< number of subquantizers per LSQ
148
+ size_t nbits, ///< number of bit per subvector index
149
+ MetricType metric = METRIC_L2,
150
+ Search_type_t search_type = AdditiveQuantizer::ST_decompress);
151
+
152
+ IndexProductLocalSearchQuantizer();
153
+
154
+ void train(idx_t n, const float* x) override;
155
+ };
156
+
101
157
  /** A "virtual" index where the elements are the residual quantizer centroids.
102
158
  *
103
159
  * Intended for use as a coarse quantizer in an IndexIVF.
@@ -121,7 +177,8 @@ struct AdditiveCoarseQuantizer : Index {
121
177
  const float* x,
122
178
  idx_t k,
123
179
  float* distances,
124
- idx_t* labels) const override;
180
+ idx_t* labels,
181
+ const SearchParameters* params = nullptr) const override;
125
182
 
126
183
  void reconstruct(idx_t key, float* recons) const override;
127
184
  void train(idx_t n, const float* x) override;
@@ -166,7 +223,8 @@ struct ResidualCoarseQuantizer : AdditiveCoarseQuantizer {
166
223
  const float* x,
167
224
  idx_t k,
168
225
  float* distances,
169
- idx_t* labels) const override;
226
+ idx_t* labels,
227
+ const SearchParameters* params = nullptr) const override;
170
228
 
171
229
  ResidualCoarseQuantizer();
172
230
  };