faiss 0.2.4 → 0.2.5

This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the respective public registry.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -17,7 +17,10 @@
17
17
 
18
18
  #include <algorithm>
19
19
 
20
+ #include <faiss/Clustering.h>
20
21
  #include <faiss/impl/FaissAssert.h>
22
+ #include <faiss/impl/LocalSearchQuantizer.h>
23
+ #include <faiss/impl/ResidualQuantizer.h>
21
24
  #include <faiss/utils/Heap.h>
22
25
  #include <faiss/utils/distances.h>
23
26
  #include <faiss/utils/hamming.h>
@@ -48,14 +51,14 @@ AdditiveQuantizer::AdditiveQuantizer(
48
51
  size_t d,
49
52
  const std::vector<size_t>& nbits,
50
53
  Search_type_t search_type)
51
- : d(d),
54
+ : Quantizer(d),
52
55
  M(nbits.size()),
53
56
  nbits(nbits),
54
57
  verbose(false),
55
58
  is_trained(false),
59
+ max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
56
60
  search_type(search_type) {
57
61
  norm_max = norm_min = NAN;
58
- code_size = 0;
59
62
  tot_bits = 0;
60
63
  total_codebook_size = 0;
61
64
  only_8bit = false;
@@ -80,27 +83,82 @@ void AdditiveQuantizer::set_derived_values() {
80
83
  }
81
84
  total_codebook_size = codebook_offsets[M];
82
85
  switch (search_type) {
83
- case ST_decompress:
84
- case ST_LUT_nonorm:
85
- case ST_norm_from_LUT:
86
- break; // nothing to add
87
86
  case ST_norm_float:
88
- tot_bits += 32;
87
+ norm_bits = 32;
89
88
  break;
90
89
  case ST_norm_qint8:
91
90
  case ST_norm_cqint8:
92
- tot_bits += 8;
91
+ case ST_norm_lsq2x4:
92
+ case ST_norm_rq2x4:
93
+ norm_bits = 8;
93
94
  break;
94
95
  case ST_norm_qint4:
95
96
  case ST_norm_cqint4:
96
- tot_bits += 4;
97
+ norm_bits = 4;
98
+ break;
99
+ case ST_decompress:
100
+ case ST_LUT_nonorm:
101
+ case ST_norm_from_LUT:
102
+ default:
103
+ norm_bits = 0;
97
104
  break;
98
105
  }
106
+ tot_bits += norm_bits;
99
107
 
100
108
  // convert bits to bytes
101
109
  code_size = (tot_bits + 7) / 8;
102
110
  }
103
111
 
112
+ void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
113
+ norm_min = HUGE_VALF;
114
+ norm_max = -HUGE_VALF;
115
+ for (idx_t i = 0; i < n; i++) {
116
+ if (norms[i] < norm_min) {
117
+ norm_min = norms[i];
118
+ }
119
+ if (norms[i] > norm_max) {
120
+ norm_max = norms[i];
121
+ }
122
+ }
123
+
124
+ if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
125
+ size_t k = (1 << 8);
126
+ if (search_type == ST_norm_cqint4) {
127
+ k = (1 << 4);
128
+ }
129
+ Clustering1D clus(k);
130
+ clus.train_exact(n, norms);
131
+ qnorm.add(clus.k, clus.centroids.data());
132
+ } else if (search_type == ST_norm_lsq2x4 || search_type == ST_norm_rq2x4) {
133
+ std::unique_ptr<AdditiveQuantizer> aq;
134
+ if (search_type == ST_norm_lsq2x4) {
135
+ aq.reset(new LocalSearchQuantizer(1, 2, 4));
136
+ } else {
137
+ aq.reset(new ResidualQuantizer(1, 2, 4));
138
+ }
139
+
140
+ aq->train(n, norms);
141
+ // flatten aq codebooks
142
+ std::vector<float> flat_codebooks(1 << 8);
143
+ FAISS_THROW_IF_NOT(aq->codebooks.size() == 32);
144
+
145
+ // save norm tables for 4-bit fastscan search
146
+ norm_tabs = aq->codebooks;
147
+
148
+ // assume big endian
149
+ const float* c = norm_tabs.data();
150
+ for (size_t i = 0; i < 16; i++) {
151
+ for (size_t j = 0; j < 16; j++) {
152
+ flat_codebooks[i * 16 + j] = c[j] + c[16 + i];
153
+ }
154
+ }
155
+
156
+ qnorm.reset();
157
+ qnorm.add(1 << 8, flat_codebooks.data());
158
+ FAISS_THROW_IF_NOT(qnorm.ntotal == (1 << 8));
159
+ }
160
+ }
161
+
104
162
  namespace {
105
163
 
106
164
  // TODO
@@ -132,7 +190,7 @@ float decode_qint4(uint8_t i, float amin, float amax) {
132
190
 
133
191
  uint32_t AdditiveQuantizer::encode_qcint(float x) const {
134
192
  idx_t id;
135
- qnorm.assign(idx_t(1), &x, &id, idx_t(1));
193
+ qnorm.assign(1, &x, &id, 1);
136
194
  return uint32_t(id);
137
195
  }
138
196
 
@@ -140,23 +198,54 @@ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
140
198
  return qnorm.get_xb()[c];
141
199
  }
142
200
 
201
+ uint64_t AdditiveQuantizer::encode_norm(float norm) const {
202
+ switch (search_type) {
203
+ case ST_norm_float:
204
+ uint32_t inorm;
205
+ memcpy(&inorm, &norm, 4);
206
+ return inorm;
207
+ case ST_norm_qint8:
208
+ return encode_qint8(norm, norm_min, norm_max);
209
+ case ST_norm_qint4:
210
+ return encode_qint4(norm, norm_min, norm_max);
211
+ case ST_norm_lsq2x4:
212
+ case ST_norm_rq2x4:
213
+ case ST_norm_cqint8:
214
+ return encode_qcint(norm);
215
+ case ST_norm_cqint4:
216
+ return encode_qcint(norm);
217
+ case ST_decompress:
218
+ case ST_LUT_nonorm:
219
+ case ST_norm_from_LUT:
220
+ default:
221
+ return 0;
222
+ }
223
+ }
224
+
143
225
  void AdditiveQuantizer::pack_codes(
144
226
  size_t n,
145
227
  const int32_t* codes,
146
228
  uint8_t* packed_codes,
147
229
  int64_t ld_codes,
148
- const float* norms) const {
230
+ const float* norms,
231
+ const float* centroids) const {
149
232
  if (ld_codes == -1) {
150
233
  ld_codes = M;
151
234
  }
152
235
  std::vector<float> norm_buf;
153
236
  if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
154
237
  search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
155
- search_type == ST_norm_cqint4) {
156
- if (!norms) {
238
+ search_type == ST_norm_cqint4 || search_type == ST_norm_lsq2x4 ||
239
+ search_type == ST_norm_rq2x4) {
240
+ if (centroids != nullptr || !norms) {
157
241
  norm_buf.resize(n);
158
242
  std::vector<float> x_recons(n * d);
159
243
  decode_unpacked(codes, x_recons.data(), n, ld_codes);
244
+
245
+ if (centroids != nullptr) {
246
+ // x = x + c
247
+ fvec_add(n * d, x_recons.data(), centroids, x_recons.data());
248
+ }
160
249
  fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
161
250
  norms = norm_buf.data();
162
251
  }
@@ -168,34 +257,8 @@ void AdditiveQuantizer::pack_codes(
168
257
  for (int m = 0; m < M; m++) {
169
258
  bsw.write(codes1[m], nbits[m]);
170
259
  }
171
- switch (search_type) {
172
- case ST_decompress:
173
- case ST_LUT_nonorm:
174
- case ST_norm_from_LUT:
175
- break;
176
- case ST_norm_float:
177
- bsw.write(*(uint32_t*)&norms[i], 32);
178
- break;
179
- case ST_norm_qint8: {
180
- uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
181
- bsw.write(b, 8);
182
- break;
183
- }
184
- case ST_norm_qint4: {
185
- uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
186
- bsw.write(b, 4);
187
- break;
188
- }
189
- case ST_norm_cqint8: {
190
- uint32_t b = encode_qcint(norms[i]);
191
- bsw.write(b, 8);
192
- break;
193
- }
194
- case ST_norm_cqint4: {
195
- uint32_t b = encode_qcint(norms[i]);
196
- bsw.write(b, 4);
197
- break;
198
- }
260
+ if (norm_bits != 0) {
261
+ bsw.write(encode_norm(norms[i]), norm_bits);
199
262
  }
200
263
  }
201
264
  }
@@ -283,28 +346,33 @@ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
283
346
  }
284
347
  }
285
348
 
286
- void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
287
- const {
349
+ void AdditiveQuantizer::compute_LUT(
350
+ size_t n,
351
+ const float* xq,
352
+ float* LUT,
353
+ float alpha,
354
+ long ld_lut) const {
288
355
  // in all cases, it is large matrix multiplication
289
356
 
290
357
  FINTEGER ncenti = total_codebook_size;
291
358
  FINTEGER di = d;
292
359
  FINTEGER nqi = n;
293
- float one = 1, zero = 0;
360
+ FINTEGER ldc = ld_lut > 0 ? ld_lut : ncenti;
361
+ float zero = 0;
294
362
 
295
363
  sgemm_("Transposed",
296
364
  "Not transposed",
297
365
  &ncenti,
298
366
  &nqi,
299
367
  &di,
300
- &one,
368
+ &alpha,
301
369
  codebooks.data(),
302
370
  &di,
303
371
  xq,
304
372
  &di,
305
373
  &zero,
306
374
  LUT,
307
- &ncenti);
375
+ &ldc);
308
376
  }
309
377
 
310
378
  namespace {
@@ -448,7 +516,8 @@ float AdditiveQuantizer::
448
516
  BitstringReader bs(codes, code_size);
449
517
  float accu = accumulate_IPs(*this, bs, codes, LUT);
450
518
  uint32_t norm_i = bs.read(32);
451
- float norm2 = *(float*)&norm_i;
519
+ float norm2;
520
+ memcpy(&norm2, &norm_i, 4);
452
521
  return norm2 - 2 * accu;
453
522
  }
454
523
 
@@ -12,6 +12,7 @@
12
12
 
13
13
  #include <faiss/Index.h>
14
14
  #include <faiss/IndexFlat.h>
15
+ #include <faiss/impl/Quantizer.h>
15
16
 
16
17
  namespace faiss {
17
18
 
@@ -21,23 +22,31 @@ namespace faiss {
21
22
  * concatenation of M sub-vectors, additive quantizers sum M sub-vectors
22
23
  * to get the decoded vector.
23
24
  */
24
- struct AdditiveQuantizer {
25
- size_t d; ///< size of the input vectors
25
+ struct AdditiveQuantizer : Quantizer {
26
26
  size_t M; ///< number of codebooks
27
27
  std::vector<size_t> nbits; ///< bits for each step
28
28
  std::vector<float> codebooks; ///< codebooks
29
29
 
30
30
  // derived values
31
31
  std::vector<uint64_t> codebook_offsets;
32
- size_t code_size; ///< code size in bytes
33
- size_t tot_bits; ///< total number of bits
32
+ size_t tot_bits; ///< total number of bits (indexes + norms)
33
+ size_t norm_bits; ///< bits allocated for the norms
34
34
  size_t total_codebook_size; ///< size of the codebook in vectors
35
35
  bool only_8bit; ///< are all nbits = 8 (use faster decoder)
36
36
 
37
37
  bool verbose; ///< verbose during training?
38
38
  bool is_trained; ///< is trained or not
39
39
 
40
- IndexFlat1D qnorm; ///< store and search norms
40
+ IndexFlat1D qnorm; ///< store and search norms
41
+ std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
42
+ ///< fastscan search
43
+
44
+ /// norms and distance matrixes with beam search can get large, so use this
45
+ /// to control for the amount of memory that can be allocated
46
+ size_t max_mem_distances;
47
+
48
+ /// encode a norm into norm_bits bits
49
+ uint64_t encode_norm(float norm) const;
41
50
 
42
51
  uint32_t encode_qcint(
43
52
  float x) const; ///< encode norm by non-uniform scalar quantization
@@ -57,6 +66,10 @@ struct AdditiveQuantizer {
57
66
  ST_norm_qint4,
58
67
  ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
59
68
  ST_norm_cqint4,
69
+
70
+ ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
71
+ ///< scan)
72
+ ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan)
60
73
  };
61
74
 
62
75
  AdditiveQuantizer(
@@ -69,16 +82,25 @@ struct AdditiveQuantizer {
69
82
  ///< compute derived values when d, M and nbits have been set
70
83
  void set_derived_values();
71
84
 
72
- ///< Train the additive quantizer
73
- virtual void train(size_t n, const float* x) = 0;
85
+ ///< Train the norm quantizer
86
+ void train_norm(size_t n, const float* norms);
87
+
88
+ void compute_codes(const float* x, uint8_t* codes, size_t n)
89
+ const override {
90
+ compute_codes_add_centroids(x, codes, n);
91
+ }
74
92
 
75
93
  /** Encode a set of vectors
76
94
  *
77
95
  * @param x vectors to encode, size n * d
78
96
  * @param codes output codes, size n * code_size
97
+ * @param centroids centroids to be added to x, size n * d
79
98
  */
80
- virtual void compute_codes(const float* x, uint8_t* codes, size_t n)
81
- const = 0;
99
+ virtual void compute_codes_add_centroids(
100
+ const float* x,
101
+ uint8_t* codes,
102
+ size_t n,
103
+ const float* centroids = nullptr) const = 0;
82
104
 
83
105
  /** pack a series of code to bit-compact format
84
106
  *
@@ -87,27 +109,29 @@ struct AdditiveQuantizer {
87
109
  * @param ld_codes leading dimension of codes
88
110
  * @param norms norms of the vectors (size n). Will be computed if
89
111
  * needed but not provided
112
+ * @param centroids centroids to be added to x, size n * d
90
113
  */
91
114
  void pack_codes(
92
115
  size_t n,
93
116
  const int32_t* codes,
94
117
  uint8_t* packed_codes,
95
118
  int64_t ld_codes = -1,
96
- const float* norms = nullptr) const;
119
+ const float* norms = nullptr,
120
+ const float* centroids = nullptr) const;
97
121
 
98
122
  /** Decode a set of vectors
99
123
  *
100
124
  * @param codes codes to decode, size n * code_size
101
125
  * @param x output vectors, size n * d
102
126
  */
103
- void decode(const uint8_t* codes, float* x, size_t n) const;
127
+ void decode(const uint8_t* codes, float* x, size_t n) const override;
104
128
 
105
129
  /** Decode a set of vectors in non-packed format
106
130
  *
107
131
  * @param codes codes to decode, size n * ld_codes
108
132
  * @param x output vectors, size n * d
109
133
  */
110
- void decode_unpacked(
134
+ virtual void decode_unpacked(
111
135
  const int32_t* codes,
112
136
  float* x,
113
137
  size_t n,
@@ -143,8 +167,15 @@ struct AdditiveQuantizer {
143
167
  *
144
168
  * @param xq query vector, size (n, d)
145
169
  * @param LUT look-up table, size (n, total_codebook_size)
170
+ * @param alpha compute alpha * inner-product
171
+ * @param ld_lut leading dimension of LUT
146
172
  */
147
- void compute_LUT(size_t n, const float* xq, float* LUT) const;
173
+ virtual void compute_LUT(
174
+ size_t n,
175
+ const float* xq,
176
+ float* LUT,
177
+ float alpha = 1.0f,
178
+ long ld_lut = -1) const;
148
179
 
149
180
  /// exact IP search
150
181
  void knn_centroids_inner_product(
@@ -199,60 +199,6 @@ void RangeSearchPartialResult::merge(
199
199
  result->lims[0] = 0;
200
200
  }
201
201
 
202
- /***********************************************************************
203
- * IDSelectorRange
204
- ***********************************************************************/
205
-
206
- IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
207
- : imin(imin), imax(imax) {}
208
-
209
- bool IDSelectorRange::is_member(idx_t id) const {
210
- return id >= imin && id < imax;
211
- }
212
-
213
- /***********************************************************************
214
- * IDSelectorArray
215
- ***********************************************************************/
216
-
217
- IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
218
-
219
- bool IDSelectorArray::is_member(idx_t id) const {
220
- for (idx_t i = 0; i < n; i++) {
221
- if (ids[i] == id)
222
- return true;
223
- }
224
- return false;
225
- }
226
-
227
- /***********************************************************************
228
- * IDSelectorBatch
229
- ***********************************************************************/
230
-
231
- IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
232
- nbits = 0;
233
- while (n > (1L << nbits))
234
- nbits++;
235
- nbits += 5;
236
- // for n = 1M, nbits = 25 is optimal, see P56659518
237
-
238
- mask = (1L << nbits) - 1;
239
- bloom.resize(1UL << (nbits - 3), 0);
240
- for (long i = 0; i < n; i++) {
241
- Index::idx_t id = indices[i];
242
- set.insert(id);
243
- id &= mask;
244
- bloom[id >> 3] |= 1 << (id & 7);
245
- }
246
- }
247
-
248
- bool IDSelectorBatch::is_member(idx_t i) const {
249
- long im = i & mask;
250
- if (!(bloom[im >> 3] & (1 << (im & 7)))) {
251
- return 0;
252
- }
253
- return set.count(i);
254
- }
255
-
256
202
  /***********************************************************
257
203
  * Interrupt callback
258
204
  ***********************************************************/
@@ -5,8 +5,6 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  // Auxiliary index structures, that are used in indexes but that can
11
9
  // be forward-declared
12
10
 
@@ -18,7 +16,6 @@
18
16
  #include <cstring>
19
17
  #include <memory>
20
18
  #include <mutex>
21
- #include <unordered_set>
22
19
  #include <vector>
23
20
 
24
21
  #include <faiss/Index.h>
@@ -52,55 +49,6 @@ struct RangeSearchResult {
52
49
  virtual ~RangeSearchResult();
53
50
  };
54
51
 
55
- /** Encapsulates a set of ids to remove. */
56
- struct IDSelector {
57
- typedef Index::idx_t idx_t;
58
- virtual bool is_member(idx_t id) const = 0;
59
- virtual ~IDSelector() {}
60
- };
61
-
62
- /** remove ids between [imni, imax) */
63
- struct IDSelectorRange : IDSelector {
64
- idx_t imin, imax;
65
-
66
- IDSelectorRange(idx_t imin, idx_t imax);
67
- bool is_member(idx_t id) const override;
68
- ~IDSelectorRange() override {}
69
- };
70
-
71
- /** simple list of elements to remove
72
- *
73
- * this is inefficient in most cases, except for IndexIVF with
74
- * maintain_direct_map
75
- */
76
- struct IDSelectorArray : IDSelector {
77
- size_t n;
78
- const idx_t* ids;
79
-
80
- IDSelectorArray(size_t n, const idx_t* ids);
81
- bool is_member(idx_t id) const override;
82
- ~IDSelectorArray() override {}
83
- };
84
-
85
- /** Remove ids from a set. Repetitions of ids in the indices set
86
- * passed to the constructor does not hurt performance. The hash
87
- * function used for the bloom filter and GCC's implementation of
88
- * unordered_set are just the least significant bits of the id. This
89
- * works fine for random ids or ids in sequences but will produce many
90
- * hash collisions if lsb's are always the same */
91
- struct IDSelectorBatch : IDSelector {
92
- std::unordered_set<idx_t> set;
93
-
94
- typedef unsigned char uint8_t;
95
- std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
96
- int nbits;
97
- idx_t mask;
98
-
99
- IDSelectorBatch(size_t n, const idx_t* indices);
100
- bool is_member(idx_t id) const override;
101
- ~IDSelectorBatch() override {}
102
- };
103
-
104
52
  /****************************************************************
105
53
  * Result structures for range search.
106
54
  *
@@ -186,30 +134,6 @@ struct RangeSearchPartialResult : BufferList {
186
134
  bool do_delete = true);
187
135
  };
188
136
 
189
- /***********************************************************
190
- * The distance computer maintains a current query and computes
191
- * distances to elements in an index that supports random access.
192
- *
193
- * The DistanceComputer is not intended to be thread-safe (eg. because
194
- * it maintains counters) so the distance functions are not const,
195
- * instantiate one from each thread if needed.
196
- ***********************************************************/
197
- struct DistanceComputer {
198
- using idx_t = Index::idx_t;
199
-
200
- /// called before computing distances. Pointer x should remain valid
201
- /// while operator () is called
202
- virtual void set_query(const float* x) = 0;
203
-
204
- /// compute distance of vector i to current query
205
- virtual float operator()(idx_t i) = 0;
206
-
207
- /// compute distance between two stored vectors
208
- virtual float symmetric_dis(idx_t i, idx_t j) = 0;
209
-
210
- virtual ~DistanceComputer() {}
211
- };
212
-
213
137
  /***********************************************************
214
138
  * Interrupt callback
215
139
  ***********************************************************/
@@ -0,0 +1,64 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #pragma once
9
+
10
+ #include <faiss/Index.h>
11
+
12
+ namespace faiss {
13
+
14
+ /***********************************************************
15
+ * The distance computer maintains a current query and computes
16
+ * distances to elements in an index that supports random access.
17
+ *
18
+ * The DistanceComputer is not intended to be thread-safe (eg. because
19
+ * it maintains counters) so the distance functions are not const,
20
+ * instantiate one from each thread if needed.
21
+ *
22
+ * Note that the equivalent for IVF indexes is the InvertedListScanner,
23
+ * that has additional methods to handle the inverted list context.
24
+ ***********************************************************/
25
+ struct DistanceComputer {
26
+ using idx_t = Index::idx_t;
27
+
28
+ /// called before computing distances. Pointer x should remain valid
29
+ /// while operator () is called
30
+ virtual void set_query(const float* x) = 0;
31
+
32
+ /// compute distance of vector i to current query
33
+ virtual float operator()(idx_t i) = 0;
34
+
35
+ /// compute distance between two stored vectors
36
+ virtual float symmetric_dis(idx_t i, idx_t j) = 0;
37
+
38
+ virtual ~DistanceComputer() {}
39
+ };
40
+
41
+ /*************************************************************
42
+ * Specialized version of the DistanceComputer when we know that codes are
43
+ * laid out in a flat index.
44
+ */
45
+ struct FlatCodesDistanceComputer : DistanceComputer {
46
+ const uint8_t* codes;
47
+ size_t code_size;
48
+
49
+ FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
50
+ : codes(codes), code_size(code_size) {}
51
+
52
+ FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
53
+
54
+ float operator()(idx_t i) final {
55
+ return distance_to_code(codes + i * code_size);
56
+ }
57
+
58
+ /// compute distance of current query to an encoded vector
59
+ virtual float distance_to_code(const uint8_t* code) = 0;
60
+
61
+ virtual ~FlatCodesDistanceComputer() {}
62
+ };
63
+
64
+ } // namespace faiss