faiss 0.2.4 → 0.2.5

Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp

@@ -17,7 +17,10 @@
 
 #include <algorithm>
 
+#include <faiss/Clustering.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/LocalSearchQuantizer.h>
+#include <faiss/impl/ResidualQuantizer.h>
 #include <faiss/utils/Heap.h>
 #include <faiss/utils/distances.h>
 #include <faiss/utils/hamming.h>
@@ -48,14 +51,14 @@ AdditiveQuantizer::AdditiveQuantizer(
         size_t d,
         const std::vector<size_t>& nbits,
         Search_type_t search_type)
-        : d(d),
+        : Quantizer(d),
           M(nbits.size()),
           nbits(nbits),
           verbose(false),
           is_trained(false),
+          max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
           search_type(search_type) {
     norm_max = norm_min = NAN;
-    code_size = 0;
     tot_bits = 0;
     total_codebook_size = 0;
     only_8bit = false;
@@ -80,27 +83,82 @@ void AdditiveQuantizer::set_derived_values() {
     }
     total_codebook_size = codebook_offsets[M];
     switch (search_type) {
-        case ST_decompress:
-        case ST_LUT_nonorm:
-        case ST_norm_from_LUT:
-            break; // nothing to add
         case ST_norm_float:
-            tot_bits += 32;
+            norm_bits = 32;
            break;
        case ST_norm_qint8:
        case ST_norm_cqint8:
-            tot_bits += 8;
+        case ST_norm_lsq2x4:
+        case ST_norm_rq2x4:
+            norm_bits = 8;
            break;
        case ST_norm_qint4:
        case ST_norm_cqint4:
-            tot_bits += 4;
+            norm_bits = 4;
+            break;
+        case ST_decompress:
+        case ST_LUT_nonorm:
+        case ST_norm_from_LUT:
+        default:
+            norm_bits = 0;
            break;
    }
+    tot_bits += norm_bits;
 
    // convert bits to bytes
    code_size = (tot_bits + 7) / 8;
 }
 
+void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
+    norm_min = HUGE_VALF;
+    norm_max = -HUGE_VALF;
+    for (idx_t i = 0; i < n; i++) {
+        if (norms[i] < norm_min) {
+            norm_min = norms[i];
+        }
+        if (norms[i] > norm_max) {
+            norm_max = norms[i];
+        }
+    }
+
+    if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
+        size_t k = (1 << 8);
+        if (search_type == ST_norm_cqint4) {
+            k = (1 << 4);
+        }
+        Clustering1D clus(k);
+        clus.train_exact(n, norms);
+        qnorm.add(clus.k, clus.centroids.data());
+    } else if (search_type == ST_norm_lsq2x4 || search_type == ST_norm_rq2x4) {
+        std::unique_ptr<AdditiveQuantizer> aq;
+        if (search_type == ST_norm_lsq2x4) {
+            aq.reset(new LocalSearchQuantizer(1, 2, 4));
+        } else {
+            aq.reset(new ResidualQuantizer(1, 2, 4));
+        }
+
+        aq->train(n, norms);
+        // flatten aq codebooks
+        std::vector<float> flat_codebooks(1 << 8);
+        FAISS_THROW_IF_NOT(aq->codebooks.size() == 32);
+
+        // save norm tables for 4-bit fastscan search
+        norm_tabs = aq->codebooks;
+
+        // assume big endian
+        const float* c = norm_tabs.data();
+        for (size_t i = 0; i < 16; i++) {
+            for (size_t j = 0; j < 16; j++) {
+                flat_codebooks[i * 16 + j] = c[j] + c[16 + i];
+            }
+        }
+
+        qnorm.reset();
+        qnorm.add(1 << 8, flat_codebooks.data());
+        FAISS_THROW_IF_NOT(qnorm.ntotal == (1 << 8));
+    }
+}
+
 namespace {
 
 // TODO
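The new ST_norm_lsq2x4 / ST_norm_rq2x4 paths in train_norm() train a tiny 2-codebook x 4-bit quantizer on the vector norms (d = 1, M = 2, nbits = 4) and then flatten its two 16-entry codebooks into one 256-entry table, so a packed norm decodes with a single 8-bit lookup. A minimal standalone sketch of just that flattening step (plain C++, no faiss dependency; the values and the code layout in the toy table are illustrative only):

    #include <cstdio>
    #include <vector>

    int main() {
        // two 4-bit codebooks of 16 scalar entries each, standing in for the
        // codebooks of a 2x4-bit quantizer trained on the vector norms
        std::vector<float> codebooks(32);
        for (int i = 0; i < 32; i++) {
            codebooks[i] = 0.1f * i; // placeholder values
        }

        // flatten into a 256-entry table: entry (i, j) decodes to the sum of
        // entry j from the first codebook and entry i from the second one
        std::vector<float> flat(256);
        for (int i = 0; i < 16; i++) {
            for (int j = 0; j < 16; j++) {
                flat[i * 16 + j] = codebooks[j] + codebooks[16 + i];
            }
        }

        // an 8-bit norm code is now a single table lookup
        int code = 3 * 16 + 7; // i = 3, j = 7 in this table layout
        std::printf("decoded norm: %f\n", flat[code]);
        return 0;
    }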
@@ -132,7 +190,7 @@ float decode_qint4(uint8_t i, float amin, float amax) {
 
 uint32_t AdditiveQuantizer::encode_qcint(float x) const {
     idx_t id;
-    qnorm.assign(idx_t(1), &x, &id, idx_t(1));
+    qnorm.assign(1, &x, &id, 1);
     return uint32_t(id);
 }
 
@@ -140,23 +198,54 @@ float AdditiveQuantizer::decode_qcint(uint32_t c) const {
     return qnorm.get_xb()[c];
 }
 
+uint64_t AdditiveQuantizer::encode_norm(float norm) const {
+    switch (search_type) {
+        case ST_norm_float:
+            uint32_t inorm;
+            memcpy(&inorm, &norm, 4);
+            return inorm;
+        case ST_norm_qint8:
+            return encode_qint8(norm, norm_min, norm_max);
+        case ST_norm_qint4:
+            return encode_qint4(norm, norm_min, norm_max);
+        case ST_norm_lsq2x4:
+        case ST_norm_rq2x4:
+        case ST_norm_cqint8:
+            return encode_qcint(norm);
+        case ST_norm_cqint4:
+            return encode_qcint(norm);
+        case ST_decompress:
+        case ST_LUT_nonorm:
+        case ST_norm_from_LUT:
+        default:
+            return 0;
+    }
+}
+
 void AdditiveQuantizer::pack_codes(
         size_t n,
         const int32_t* codes,
         uint8_t* packed_codes,
         int64_t ld_codes,
-        const float* norms) const {
+        const float* norms,
+        const float* centroids) const {
     if (ld_codes == -1) {
         ld_codes = M;
     }
     std::vector<float> norm_buf;
     if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
         search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
-        search_type == ST_norm_cqint4) {
-        if (!norms) {
+        search_type == ST_norm_cqint4 || search_type == ST_norm_lsq2x4 ||
+        search_type == ST_norm_rq2x4) {
+        if (centroids != nullptr || !norms) {
            norm_buf.resize(n);
            std::vector<float> x_recons(n * d);
            decode_unpacked(codes, x_recons.data(), n, ld_codes);
+
+            if (centroids != nullptr) {
+                // x = x + c
+                fvec_add(n * d, x_recons.data(), centroids, x_recons.data());
+            }
            fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
            norms = norm_buf.data();
        }
@@ -168,34 +257,8 @@ void AdditiveQuantizer::pack_codes(
        for (int m = 0; m < M; m++) {
            bsw.write(codes1[m], nbits[m]);
        }
-        switch (search_type) {
-            case ST_decompress:
-            case ST_LUT_nonorm:
-            case ST_norm_from_LUT:
-                break;
-            case ST_norm_float:
-                bsw.write(*(uint32_t*)&norms[i], 32);
-                break;
-            case ST_norm_qint8: {
-                uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
-                bsw.write(b, 8);
-                break;
-            }
-            case ST_norm_qint4: {
-                uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
-                bsw.write(b, 4);
-                break;
-            }
-            case ST_norm_cqint8: {
-                uint32_t b = encode_qcint(norms[i]);
-                bsw.write(b, 8);
-                break;
-            }
-            case ST_norm_cqint4: {
-                uint32_t b = encode_qcint(norms[i]);
-                bsw.write(b, 4);
-                break;
-            }
+        if (norm_bits != 0) {
+            bsw.write(encode_norm(norms[i]), norm_bits);
        }
    }
 }
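With the per-type switch in pack_codes() collapsed into encode_norm() plus norm_bits, the ST_norm_float case simply stores the raw bit pattern of the float norm, and the matching decode path (see the memcpy change further down) recovers it without the old *(float*)& cast, which relied on type punning. A tiny standalone sketch of that round trip, assuming nothing beyond standard C++:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    int main() {
        float norm = 3.5f;

        // encode: copy the float's bit pattern into a 32-bit code
        uint32_t code;
        std::memcpy(&code, &norm, 4);

        // decode: copy the bits back into a float
        float decoded;
        std::memcpy(&decoded, &code, 4);

        std::printf("norm=%f code=0x%08x decoded=%f\n",
                    norm, (unsigned)code, decoded);
        return 0;
    }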
@@ -283,28 +346,33 @@ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
     }
 }
 
-void AdditiveQuantizer::compute_LUT(size_t n, const float* xq, float* LUT)
-        const {
+void AdditiveQuantizer::compute_LUT(
+        size_t n,
+        const float* xq,
+        float* LUT,
+        float alpha,
+        long ld_lut) const {
     // in all cases, it is large matrix multiplication
 
     FINTEGER ncenti = total_codebook_size;
     FINTEGER di = d;
     FINTEGER nqi = n;
-    float one = 1, zero = 0;
+    FINTEGER ldc = ld_lut > 0 ? ld_lut : ncenti;
+    float zero = 0;
 
     sgemm_("Transposed",
            "Not transposed",
            &ncenti,
            &nqi,
            &di,
-           &one,
+           &alpha,
            codebooks.data(),
            &di,
            xq,
            &di,
            &zero,
            LUT,
-           &ncenti);
+           &ldc);
 }
 
 namespace {
@@ -448,7 +516,8 @@ float AdditiveQuantizer::
     BitstringReader bs(codes, code_size);
     float accu = accumulate_IPs(*this, bs, codes, LUT);
     uint32_t norm_i = bs.read(32);
-    float norm2 = *(float*)&norm_i;
+    float norm2;
+    memcpy(&norm2, &norm_i, 4);
     return norm2 - 2 * accu;
 }
 
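The reworked compute_LUT() is still a single sgemm call, but it now scales the inner products by alpha and writes each query's row with a caller-chosen leading dimension ld_lut. A naive standalone sketch of the quantity it computes (illustrative function name; the real code delegates this loop nest to BLAS):

    #include <cstddef>

    // LUT[q * ld_lut + c] = alpha * <xq_q, codebook_c>
    // xq:        n x d queries, row-major
    // codebooks: total_codebook_size x d centroids, row-major
    void compute_lut_naive(
            size_t n,
            size_t d,
            size_t total_codebook_size,
            const float* xq,
            const float* codebooks,
            float* LUT,
            float alpha,
            size_t ld_lut) {
        for (size_t q = 0; q < n; q++) {
            for (size_t c = 0; c < total_codebook_size; c++) {
                float ip = 0;
                for (size_t k = 0; k < d; k++) {
                    ip += xq[q * d + k] * codebooks[c * d + k];
                }
                LUT[q * ld_lut + c] = alpha * ip;
            }
        }
    }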
data/vendor/faiss/faiss/impl/AdditiveQuantizer.h

@@ -12,6 +12,7 @@
 
 #include <faiss/Index.h>
 #include <faiss/IndexFlat.h>
+#include <faiss/impl/Quantizer.h>
 
 namespace faiss {
 
@@ -21,23 +22,31 @@ namespace faiss {
  * concatenation of M sub-vectors, additive quantizers sum M sub-vectors
  * to get the decoded vector.
  */
-struct AdditiveQuantizer {
-    size_t d; ///< size of the input vectors
+struct AdditiveQuantizer : Quantizer {
     size_t M;                  ///< number of codebooks
     std::vector<size_t> nbits; ///< bits for each step
     std::vector<float> codebooks; ///< codebooks
 
     // derived values
     std::vector<uint64_t> codebook_offsets;
-    size_t code_size; ///< code size in bytes
-    size_t tot_bits;  ///< total number of bits
+    size_t tot_bits;  ///< total number of bits (indexes + norms)
+    size_t norm_bits; ///< bits allocated for the norms
     size_t total_codebook_size; ///< size of the codebook in vectors
     bool only_8bit;             ///< are all nbits = 8 (use faster decoder)
 
     bool verbose;    ///< verbose during training?
     bool is_trained; ///< is trained or not
 
-    IndexFlat1D qnorm; ///< store and search norms
+    IndexFlat1D qnorm;            ///< store and search norms
+    std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
+                                  ///< fastscan search
+
+    /// norms and distance matrixes with beam search can get large, so use this
+    /// to control for the amount of memory that can be allocated
+    size_t max_mem_distances;
+
+    /// encode a norm into norm_bits bits
+    uint64_t encode_norm(float norm) const;
 
     uint32_t encode_qcint(
             float x) const; ///< encode norm by non-uniform scalar quantization
@@ -57,6 +66,10 @@ struct AdditiveQuantizer {
        ST_norm_qint4,
        ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
        ST_norm_cqint4,
+
+        ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
+                        ///< scan)
+        ST_norm_rq2x4,  ///< use a 2x4 bits rq as norm quantizer (for fast scan)
    };
 
    AdditiveQuantizer(
@@ -69,16 +82,25 @@ struct AdditiveQuantizer {
     ///< compute derived values when d, M and nbits have been set
     void set_derived_values();
 
-    ///< Train the additive quantizer
-    virtual void train(size_t n, const float* x) = 0;
+    ///< Train the norm quantizer
+    void train_norm(size_t n, const float* norms);
+
+    void compute_codes(const float* x, uint8_t* codes, size_t n)
+            const override {
+        compute_codes_add_centroids(x, codes, n);
+    }
 
     /** Encode a set of vectors
      *
      * @param x      vectors to encode, size n * d
     * @param codes  output codes, size n * code_size
+     * @param centroids  centroids to be added to x, size n * d
     */
-    virtual void compute_codes(const float* x, uint8_t* codes, size_t n)
-            const = 0;
+    virtual void compute_codes_add_centroids(
+            const float* x,
+            uint8_t* codes,
+            size_t n,
+            const float* centroids = nullptr) const = 0;
 
    /** pack a series of code to bit-compact format
     *
@@ -87,27 +109,29 @@ struct AdditiveQuantizer {
     * @param ld_codes   leading dimension of codes
     * @param norms      norms of the vectors (size n). Will be computed if
     *                   needed but not provided
+     * @param centroids  centroids to be added to x, size n * d
     */
    void pack_codes(
            size_t n,
            const int32_t* codes,
            uint8_t* packed_codes,
            int64_t ld_codes = -1,
-            const float* norms = nullptr) const;
+            const float* norms = nullptr,
+            const float* centroids = nullptr) const;
 
    /** Decode a set of vectors
     *
     * @param codes  codes to decode, size n * code_size
     * @param x      output vectors, size n * d
     */
-    void decode(const uint8_t* codes, float* x, size_t n) const;
+    void decode(const uint8_t* codes, float* x, size_t n) const override;
 
    /** Decode a set of vectors in non-packed format
     *
     * @param codes  codes to decode, size n * ld_codes
     * @param x      output vectors, size n * d
     */
-    void decode_unpacked(
+    virtual void decode_unpacked(
            const int32_t* codes,
            float* x,
            size_t n,
@@ -143,8 +167,15 @@ struct AdditiveQuantizer {
     *
     * @param xq      query vector, size (n, d)
     * @param LUT     look-up table, size (n, total_codebook_size)
+     * @param alpha   compute alpha * inner-product
+     * @param ld_lut  leading dimension of LUT
     */
-    void compute_LUT(size_t n, const float* xq, float* LUT) const;
+    virtual void compute_LUT(
+            size_t n,
+            const float* xq,
+            float* LUT,
+            float alpha = 1.0f,
+            long ld_lut = -1) const;
 
    /// exact IP search
    void knn_centroids_inner_product(
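With AdditiveQuantizer now deriving from the new Quantizer base, compute_codes() is a thin wrapper that forwards to compute_codes_add_centroids(). A minimal usage sketch through a concrete subclass; it assumes the ResidualQuantizer(d, M, nbits) constructor and the train/compute_codes/decode signatures from the faiss public headers, so treat it as illustrative rather than exact:

    #include <cstdint>
    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    #include <faiss/impl/ResidualQuantizer.h>

    int main() {
        size_t d = 32, n = 1000;
        std::vector<float> x(n * d);
        for (size_t i = 0; i < x.size(); i++) {
            x[i] = float(std::rand()) / RAND_MAX; // toy training data
        }

        // residual quantizer with 4 codebooks of 8 bits each
        faiss::ResidualQuantizer rq(d, 4, 8);
        rq.train(n, x.data());

        // compute_codes() now forwards to compute_codes_add_centroids()
        // with a null centroid pointer
        std::vector<uint8_t> codes(n * rq.code_size);
        rq.compute_codes(x.data(), codes.data(), n);

        // decode back to approximate float vectors
        std::vector<float> decoded(n * d);
        rq.decode(codes.data(), decoded.data(), n);

        std::printf("code_size = %zu bytes per vector\n", rq.code_size);
        return 0;
    }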
data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp

@@ -199,60 +199,6 @@ void RangeSearchPartialResult::merge(
     result->lims[0] = 0;
 }
 
-/***********************************************************************
- * IDSelectorRange
- ***********************************************************************/
-
-IDSelectorRange::IDSelectorRange(idx_t imin, idx_t imax)
-        : imin(imin), imax(imax) {}
-
-bool IDSelectorRange::is_member(idx_t id) const {
-    return id >= imin && id < imax;
-}
-
-/***********************************************************************
- * IDSelectorArray
- ***********************************************************************/
-
-IDSelectorArray::IDSelectorArray(size_t n, const idx_t* ids) : n(n), ids(ids) {}
-
-bool IDSelectorArray::is_member(idx_t id) const {
-    for (idx_t i = 0; i < n; i++) {
-        if (ids[i] == id)
-            return true;
-    }
-    return false;
-}
-
-/***********************************************************************
- * IDSelectorBatch
- ***********************************************************************/
-
-IDSelectorBatch::IDSelectorBatch(size_t n, const idx_t* indices) {
-    nbits = 0;
-    while (n > (1L << nbits))
-        nbits++;
-    nbits += 5;
-    // for n = 1M, nbits = 25 is optimal, see P56659518
-
-    mask = (1L << nbits) - 1;
-    bloom.resize(1UL << (nbits - 3), 0);
-    for (long i = 0; i < n; i++) {
-        Index::idx_t id = indices[i];
-        set.insert(id);
-        id &= mask;
-        bloom[id >> 3] |= 1 << (id & 7);
-    }
-}
-
-bool IDSelectorBatch::is_member(idx_t i) const {
-    long im = i & mask;
-    if (!(bloom[im >> 3] & (1 << (im & 7)))) {
-        return 0;
-    }
-    return set.count(i);
-}
-
 /***********************************************************
  * Interrupt callback
  ***********************************************************/
data/vendor/faiss/faiss/impl/AuxIndexStructures.h

@@ -5,8 +5,6 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-// -*- c++ -*-
-
 // Auxiliary index structures, that are used in indexes but that can
 // be forward-declared
 
@@ -18,7 +16,6 @@
 #include <cstring>
 #include <memory>
 #include <mutex>
-#include <unordered_set>
 #include <vector>
 
 #include <faiss/Index.h>
@@ -52,55 +49,6 @@ struct RangeSearchResult {
     virtual ~RangeSearchResult();
 };
 
-/** Encapsulates a set of ids to remove. */
-struct IDSelector {
-    typedef Index::idx_t idx_t;
-    virtual bool is_member(idx_t id) const = 0;
-    virtual ~IDSelector() {}
-};
-
-/** remove ids between [imni, imax) */
-struct IDSelectorRange : IDSelector {
-    idx_t imin, imax;
-
-    IDSelectorRange(idx_t imin, idx_t imax);
-    bool is_member(idx_t id) const override;
-    ~IDSelectorRange() override {}
-};
-
-/** simple list of elements to remove
- *
- * this is inefficient in most cases, except for IndexIVF with
- * maintain_direct_map
- */
-struct IDSelectorArray : IDSelector {
-    size_t n;
-    const idx_t* ids;
-
-    IDSelectorArray(size_t n, const idx_t* ids);
-    bool is_member(idx_t id) const override;
-    ~IDSelectorArray() override {}
-};
-
-/** Remove ids from a set. Repetitions of ids in the indices set
- * passed to the constructor does not hurt performance. The hash
- * function used for the bloom filter and GCC's implementation of
- * unordered_set are just the least significant bits of the id. This
- * works fine for random ids or ids in sequences but will produce many
- * hash collisions if lsb's are always the same */
-struct IDSelectorBatch : IDSelector {
-    std::unordered_set<idx_t> set;
-
-    typedef unsigned char uint8_t;
-    std::vector<uint8_t> bloom; // assumes low bits of id are a good hash value
-    int nbits;
-    idx_t mask;
-
-    IDSelectorBatch(size_t n, const idx_t* indices);
-    bool is_member(idx_t id) const override;
-    ~IDSelectorBatch() override {}
-};
-
 /****************************************************************
  * Result structures for range search.
  *
@@ -186,30 +134,6 @@ struct RangeSearchPartialResult : BufferList {
             bool do_delete = true);
 };
 
-/***********************************************************
- * The distance computer maintains a current query and computes
- * distances to elements in an index that supports random access.
- *
- * The DistanceComputer is not intended to be thread-safe (eg. because
- * it maintains counters) so the distance functions are not const,
- * instantiate one from each thread if needed.
- ***********************************************************/
-struct DistanceComputer {
-    using idx_t = Index::idx_t;
-
-    /// called before computing distances. Pointer x should remain valid
-    /// while operator () is called
-    virtual void set_query(const float* x) = 0;
-
-    /// compute distance of vector i to current query
-    virtual float operator()(idx_t i) = 0;
-
-    /// compute distance between two stored vectors
-    virtual float symmetric_dis(idx_t i, idx_t j) = 0;
-
-    virtual ~DistanceComputer() {}
-};
-
 /***********************************************************
  * Interrupt callback
  ***********************************************************/
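The IDSelector hierarchy removed from AuxIndexStructures.{h,cpp} above appears to move to the new impl/IDSelector.{h,cpp} files listed in this release (entries 127 and 128 in the file list). A minimal sketch of client code after the move, assuming the class interfaces shown in the removed declarations are unchanged and only the include path differs:

    #include <cstdio>

    #include <faiss/impl/IDSelector.h> // was faiss/impl/AuxIndexStructures.h

    int main() {
        // selector that matches ids in [100, 200)
        faiss::IDSelectorRange sel(100, 200);

        std::printf("is_member(150) = %d\n", int(sel.is_member(150)));
        std::printf("is_member(250) = %d\n", int(sel.is_member(250)));
        return 0;
    }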
data/vendor/faiss/faiss/impl/DistanceComputer.h (new file)

@@ -0,0 +1,64 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/Index.h>
+
+namespace faiss {
+
+/***********************************************************
+ * The distance computer maintains a current query and computes
+ * distances to elements in an index that supports random access.
+ *
+ * The DistanceComputer is not intended to be thread-safe (eg. because
+ * it maintains counters) so the distance functions are not const,
+ * instantiate one from each thread if needed.
+ *
+ * Note that the equivalent for IVF indexes is the InvertedListScanner,
+ * that has additional methods to handle the inverted list context.
+ ***********************************************************/
+struct DistanceComputer {
+    using idx_t = Index::idx_t;
+
+    /// called before computing distances. Pointer x should remain valid
+    /// while operator () is called
+    virtual void set_query(const float* x) = 0;
+
+    /// compute distance of vector i to current query
+    virtual float operator()(idx_t i) = 0;
+
+    /// compute distance between two stored vectors
+    virtual float symmetric_dis(idx_t i, idx_t j) = 0;
+
+    virtual ~DistanceComputer() {}
+};
+
+/*************************************************************
+ * Specialized version of the DistanceComputer when we know that codes are
+ * laid out in a flat index.
+ */
+struct FlatCodesDistanceComputer : DistanceComputer {
+    const uint8_t* codes;
+    size_t code_size;
+
+    FlatCodesDistanceComputer(const uint8_t* codes, size_t code_size)
+            : codes(codes), code_size(code_size) {}
+
+    FlatCodesDistanceComputer() : codes(nullptr), code_size(0) {}
+
+    float operator()(idx_t i) final {
+        return distance_to_code(codes + i * code_size);
+    }
+
+    /// compute distance of current query to an encoded vector
+    virtual float distance_to_code(const uint8_t* code) = 0;
+
+    virtual ~FlatCodesDistanceComputer() {}
+};
+
+} // namespace faiss
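The new header also introduces FlatCodesDistanceComputer, which a subclass can implement with distance_to_code() alone because operator() resolves the code pointer from the flat layout. A sketch of a hypothetical subclass over plain float codes, based solely on the interface shown above (the class name and the L2 loop are illustrative, not part of faiss):

    #include <cstddef>
    #include <cstdint>

    #include <faiss/impl/DistanceComputer.h>

    // hypothetical L2 distance computer over "codes" that are raw float vectors
    struct MyFloatL2Computer : faiss::FlatCodesDistanceComputer {
        size_t d;
        const float* q = nullptr;

        MyFloatL2Computer(const float* xb, size_t d)
                : faiss::FlatCodesDistanceComputer(
                          reinterpret_cast<const uint8_t*>(xb),
                          d * sizeof(float)),
                  d(d) {}

        void set_query(const float* x) override {
            q = x;
        }

        // the base class's operator()(i) calls this with codes + i * code_size
        float distance_to_code(const uint8_t* code) override {
            const float* y = reinterpret_cast<const float*>(code);
            float sum = 0;
            for (size_t k = 0; k < d; k++) {
                float diff = q[k] - y[k];
                sum += diff * diff;
            }
            return sum;
        }

        float symmetric_dis(idx_t i, idx_t j) override {
            const float* yi =
                    reinterpret_cast<const float*>(codes + i * code_size);
            const float* yj =
                    reinterpret_cast<const float*>(codes + j * code_size);
            float sum = 0;
            for (size_t k = 0; k < d; k++) {
                float diff = yi[k] - yj[k];
                sum += diff * diff;
            }
            return sum;
        }
    };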