faiss 0.2.4 → 0.2.5

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -5,21 +5,15 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/impl/ResidualQuantizer.h>
11
9
 
12
10
  #include <algorithm>
11
+ #include <cmath>
13
12
  #include <cstddef>
14
13
  #include <cstdio>
15
14
  #include <cstring>
16
15
  #include <memory>
17
16
 
18
- #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/impl/ResidualQuantizer.h>
20
- #include <faiss/utils/utils.h>
21
-
22
- #include <faiss/Clustering.h>
23
17
  #include <faiss/IndexFlat.h>
24
18
  #include <faiss/VectorTransform.h>
25
19
  #include <faiss/impl/AuxIndexStructures.h>
@@ -27,7 +21,6 @@
27
21
  #include <faiss/utils/Heap.h>
28
22
  #include <faiss/utils/distances.h>
29
23
  #include <faiss/utils/hamming.h>
30
- #include <faiss/utils/simdlib.h>
31
24
  #include <faiss/utils/utils.h>
32
25
 
33
26
  extern "C" {
@@ -47,15 +40,34 @@ int sgemm_(
47
40
  float* beta,
48
41
  float* c,
49
42
  FINTEGER* ldc);
43
+
44
+ // http://www.netlib.org/clapack/old/single/sgels.c
45
+ // solve least squares (note: the routine declared below is sgelsd_, the SVD-based driver, not sgels)
46
+
47
+ int sgelsd_(
48
+ FINTEGER* m,
49
+ FINTEGER* n,
50
+ FINTEGER* nrhs,
51
+ float* a,
52
+ FINTEGER* lda,
53
+ float* b,
54
+ FINTEGER* ldb,
55
+ float* s,
56
+ float* rcond,
57
+ FINTEGER* rank,
58
+ float* work,
59
+ FINTEGER* lwork,
60
+ FINTEGER* iwork,
61
+ FINTEGER* info);
50
62
  }
51
63
 
52
64
  namespace faiss {
53
65
 
54
66
  ResidualQuantizer::ResidualQuantizer()
55
67
  : train_type(Train_progressive_dim),
68
+ niter_codebook_refine(5),
56
69
  max_beam_size(5),
57
70
  use_beam_LUT(0),
58
- max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
59
71
  assign_index_factory(nullptr) {
60
72
  d = 0;
61
73
  M = 0;
@@ -81,6 +93,39 @@ ResidualQuantizer::ResidualQuantizer(
81
93
  Search_type_t search_type)
82
94
  : ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
83
95
 
96
+ void ResidualQuantizer::initialize_from(
97
+ const ResidualQuantizer& other,
98
+ int skip_M) {
99
+ FAISS_THROW_IF_NOT(M + skip_M <= other.M);
100
+ FAISS_THROW_IF_NOT(skip_M >= 0);
101
+
102
+ Search_type_t this_search_type = search_type;
103
+ int this_M = M;
104
+
105
+ // a first good approximation: override everything
106
+ *this = other;
107
+
108
+ // adjust derived values
109
+ M = this_M;
110
+ search_type = this_search_type;
111
+ nbits.resize(M);
112
+ memcpy(nbits.data(),
113
+ other.nbits.data() + skip_M,
114
+ nbits.size() * sizeof(nbits[0]));
115
+
116
+ set_derived_values();
117
+
118
+ // resize codebooks if trained
119
+ if (codebooks.size() > 0) {
120
+ FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
121
+ codebooks.resize(total_codebook_size * d);
122
+ memcpy(codebooks.data(),
123
+ other.codebooks.data() + other.codebook_offsets[skip_M] * d,
124
+ codebooks.size() * sizeof(codebooks[0]));
125
+ // TODO: norm_tabs?
126
+ }
127
+ }
128
+
84
129
  void beam_search_encode_step(
85
130
  size_t d,
86
131
  size_t K,
@@ -245,8 +290,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
245
290
  }
246
291
  train_residuals = residuals1;
247
292
  }
248
- train_type_t tt = train_type_t(train_type & 1023);
249
-
250
293
  std::vector<float> codebooks;
251
294
  float obj = 0;
252
295
 
@@ -259,7 +302,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
259
302
 
260
303
  double t1 = getmillisecs();
261
304
 
262
- if (tt == Train_default) {
305
+ if (!(train_type & Train_progressive_dim)) { // regular kmeans
263
306
  Clustering clus(d, K, cp);
264
307
  clus.train(
265
308
  train_residuals.size() / d,
@@ -268,7 +311,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
268
311
  codebooks.swap(clus.centroids);
269
312
  assign_index->reset();
270
313
  obj = clus.iteration_stats.back().obj;
271
- } else if (tt == Train_progressive_dim) {
314
+ } else { // progressive dim clustering
272
315
  ProgressiveDimClustering clus(d, K, cp);
273
316
  ProgressiveDimIndexFactory default_fac;
274
317
  clus.train(
@@ -277,8 +320,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
277
320
  assign_index_factory ? *assign_index_factory : default_fac);
278
321
  codebooks.swap(clus.centroids);
279
322
  obj = clus.iteration_stats.back().obj;
280
- } else {
281
- FAISS_THROW_MSG("train type not supported");
282
323
  }
283
324
  clustering_time += (getmillisecs() - t1) / 1000;
284
325
 
@@ -350,6 +391,19 @@ void ResidualQuantizer::train(size_t n, const float* x) {
350
391
  cur_beam_size = new_beam_size;
351
392
  }
352
393
 
394
+ is_trained = true;
395
+
396
+ if (train_type & Train_refine_codebook) {
397
+ for (int iter = 0; iter < niter_codebook_refine; iter++) {
398
+ if (verbose) {
399
+ printf("re-estimating the codebooks to minimize "
400
+ "quantization errors (iter %d).\n",
401
+ iter);
402
+ }
403
+ retrain_AQ_codebook(n, x);
404
+ }
405
+ }
406
+
353
407
  // find min and max norms
354
408
  std::vector<float> norms(n);
355
409
 
@@ -359,33 +413,128 @@ void ResidualQuantizer::train(size_t n, const float* x) {
359
413
  }
360
414
 
361
415
  // fvec_norms_L2sqr(norms.data(), x, d, n);
416
+ train_norm(n, norms.data());
417
+
418
+ if (!(train_type & Skip_codebook_tables)) {
419
+ compute_codebook_tables();
420
+ }
421
+ }
422
+
423
+ float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
424
+ FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
425
+
426
+ if (verbose) {
427
+ printf(" encoding %zd training vectors\n", n);
428
+ }
429
+ std::vector<uint8_t> codes(n * code_size);
430
+ compute_codes(x, codes.data(), n);
362
431
 
363
- norm_min = HUGE_VALF;
364
- norm_max = -HUGE_VALF;
365
- for (idx_t i = 0; i < n; i++) {
366
- if (norms[i] < norm_min) {
367
- norm_min = norms[i];
432
+ // compute reconstruction error
433
+ float input_recons_error;
434
+ {
435
+ std::vector<float> x_recons(n * d);
436
+ decode(codes.data(), x_recons.data(), n);
437
+ input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
438
+ if (verbose) {
439
+ printf(" input quantization error %g\n", input_recons_error);
368
440
  }
369
- if (norms[i] > norm_max) {
370
- norm_max = norms[i];
441
+ }
442
+
443
+ // build matrix of the linear system
444
+ std::vector<float> C(n * total_codebook_size);
445
+ for (size_t i = 0; i < n; i++) {
446
+ BitstringReader bsr(codes.data() + i * code_size, code_size);
447
+ for (int m = 0; m < M; m++) {
448
+ int idx = bsr.read(nbits[m]);
449
+ C[i + (codebook_offsets[m] + idx) * n] = 1;
450
+ }
451
+ }
452
+
453
+ // transpose training vectors
454
+ std::vector<float> xt(n * d);
455
+
456
+ for (size_t i = 0; i < n; i++) {
457
+ for (size_t j = 0; j < d; j++) {
458
+ xt[j * n + i] = x[i * d + j];
371
459
  }
372
460
  }
373
461
 
374
- if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
375
- size_t k = (1 << 8);
376
- if (search_type == ST_norm_cqint4) {
377
- k = (1 << 4);
462
+ { // solve least squares
463
+ FINTEGER lwork = -1;
464
+ FINTEGER di = d, ni = n, tcsi = total_codebook_size;
465
+ FINTEGER info = -1, rank = -1;
466
+
467
+ float rcond = 1e-4; // this is an important parameter because the code
468
+ // matrix can be rank deficient for small problems,
469
+ // the default rcond=-1 does not work
470
+ float worksize;
471
+ std::vector<float> sing_vals(total_codebook_size);
472
+ FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
473
+ // upper bound
474
+ std::vector<FINTEGER> iwork(
475
+ 3 * total_codebook_size * nlvl + 11 * total_codebook_size);
476
+
477
+ // worksize query
478
+ sgelsd_(&ni,
479
+ &tcsi,
480
+ &di,
481
+ C.data(),
482
+ &ni,
483
+ xt.data(),
484
+ &ni,
485
+ sing_vals.data(),
486
+ &rcond,
487
+ &rank,
488
+ &worksize,
489
+ &lwork,
490
+ iwork.data(),
491
+ &info);
492
+ FAISS_THROW_IF_NOT(info == 0);
493
+
494
+ lwork = worksize;
495
+ std::vector<float> work(lwork);
496
+ // actual call
497
+ sgelsd_(&ni,
498
+ &tcsi,
499
+ &di,
500
+ C.data(),
501
+ &ni,
502
+ xt.data(),
503
+ &ni,
504
+ sing_vals.data(),
505
+ &rcond,
506
+ &rank,
507
+ work.data(),
508
+ &lwork,
509
+ iwork.data(),
510
+ &info);
511
+ FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
512
+ if (verbose) {
513
+ printf(" sgelsd rank=%d/%d\n",
514
+ int(rank),
515
+ int(total_codebook_size));
378
516
  }
379
- Clustering1D clus(k);
380
- clus.train_exact(n, norms.data());
381
- qnorm.add(clus.k, clus.centroids.data());
382
517
  }
383
518
 
384
- is_trained = true;
519
+ // result is in xt, re-transpose to codebook
385
520
 
386
- if (!(train_type & Skip_codebook_tables)) {
387
- compute_codebook_tables();
521
+ for (size_t i = 0; i < total_codebook_size; i++) {
522
+ for (size_t j = 0; j < d; j++) {
523
+ codebooks[i * d + j] = xt[j * n + i];
524
+ FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
525
+ }
526
+ }
527
+
528
+ float output_recons_error = 0;
529
+ for (size_t j = 0; j < d; j++) {
530
+ output_recons_error += fvec_norm_L2sqr(
531
+ xt.data() + total_codebook_size + n * j,
532
+ n - total_codebook_size);
388
533
  }
534
+ if (verbose) {
535
+ printf(" output quantization error %g\n", output_recons_error);
536
+ }
537
+ return output_recons_error;
389
538
  }
390
539
 
391
540
  size_t ResidualQuantizer::memory_per_point(int beam_size) const {
@@ -400,10 +549,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
400
549
  return mem;
401
550
  }
402
551
 
403
- void ResidualQuantizer::compute_codes(
552
+ void ResidualQuantizer::compute_codes_add_centroids(
404
553
  const float* x,
405
554
  uint8_t* codes_out,
406
- size_t n) const {
555
+ size_t n,
556
+ const float* centroids) const {
407
557
  FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
408
558
 
409
559
  size_t mem = memory_per_point();
@@ -415,7 +565,12 @@ void ResidualQuantizer::compute_codes(
415
565
  }
416
566
  for (size_t i0 = 0; i0 < n; i0 += bs) {
417
567
  size_t i1 = std::min(n, i0 + bs);
418
- compute_codes(x + i0 * d, codes_out + i0 * code_size, i1 - i0);
568
+ const float* cent = nullptr;
569
+ if (centroids != nullptr) {
570
+ cent = centroids + i0 * d;
571
+ }
572
+ compute_codes_add_centroids(
573
+ x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
419
574
  }
420
575
  return;
421
576
  }
@@ -489,7 +644,8 @@ void ResidualQuantizer::compute_codes(
489
644
  codes.data(),
490
645
  codes_out,
491
646
  M * max_beam_size,
492
- norms.size() > 0 ? norms.data() : nullptr);
647
+ norms.size() > 0 ? norms.data() : nullptr,
648
+ centroids);
493
649
  }
494
650
 
495
651
  void ResidualQuantizer::refine_beam(
@@ -24,25 +24,31 @@ namespace faiss {
24
24
 
25
25
  struct ResidualQuantizer : AdditiveQuantizer {
26
26
  /// initialization
27
- enum train_type_t {
28
- Train_default = 0, ///< regular k-means
29
- Train_progressive_dim = 1, ///< progressive dim clustering
30
- Train_default_Train_top_beam = 1024,
31
- Train_progressive_dim_Train_top_beam = 1025,
32
- Train_default_Skip_codebook_tables = 2048,
33
- Train_progressive_dim_Skip_codebook_tables = 2049,
34
- Train_default_Train_top_beam_Skip_codebook_tables = 3072,
35
- Train_progressive_dim_Train_top_beam_Skip_codebook_tables = 3073,
36
- };
37
27
 
28
+ // Was enum but that does not work so well with bitmasks
29
+ using train_type_t = int;
30
+
31
+ /// Binary or of the Train_* flags below
38
32
  train_type_t train_type;
39
33
 
40
- // set this bit on train_type if beam is to be trained only on the
41
- // first element of the beam (faster but less accurate)
34
+ /// regular k-means (minimal amount of computation)
35
+ static const int Train_default = 0;
36
+
37
+ /// progressive dim clustering (set by default)
38
+ static const int Train_progressive_dim = 1;
39
+
40
+ /// do a few iterations of codebook refinement after first level estimation
41
+ static const int Train_refine_codebook = 2;
42
+
43
+ /// number of iterations for codebook refinement.
44
+ int niter_codebook_refine;
45
+
46
+ /** set this bit on train_type if beam is to be trained only on the
47
+ * first element of the beam (faster but less accurate) */
42
48
  static const int Train_top_beam = 1024;
43
49
 
44
- // set this bit to not autmatically compute the codebook tables
45
- // after training
50
+ /** set this bit to *not* automatically compute the codebook tables
51
+ * after training */
46
52
  static const int Skip_codebook_tables = 2048;
47
53
 
48
54
  /// beam size used for training and for encoding
@@ -51,10 +57,6 @@ struct ResidualQuantizer : AdditiveQuantizer {
51
57
  /// use LUT for beam search
52
58
  int use_beam_LUT;
53
59
 
54
- /// distance matrixes with beam search can get large, so use this
55
- /// to batch computations at encoding time.
56
- size_t max_mem_distances;
57
-
58
60
  /// clustering parameters
59
61
  ProgressiveDimClusteringParameters cp;
60
62
 
@@ -74,15 +76,33 @@ struct ResidualQuantizer : AdditiveQuantizer {
74
76
 
75
77
  ResidualQuantizer();
76
78
 
77
- // Train the residual quantizer
79
+ /// Train the residual quantizer
78
80
  void train(size_t n, const float* x) override;
79
81
 
82
+ /// Copy the M codebook levels from other, starting from skip_M
83
+ void initialize_from(const ResidualQuantizer& other, int skip_M = 0);
84
+
85
+ /** Encode the vectors and compute codebook that minimizes the quantization
86
+ * error on these codes
87
+ *
88
+ * @param x training vectors, size n * d
89
+ * @param n nb of training vectors, n >= total_codebook_size
90
+ * @return quantization error for the new codebook with old
91
+ * codes
92
+ */
93
+ float retrain_AQ_codebook(size_t n, const float* x);
94
+
80
95
  /** Encode a set of vectors
81
96
  *
82
97
  * @param x vectors to encode, size n * d
83
98
  * @param codes output codes, size n * code_size
99
+ * @param centroids centroids to be added to x, size n * d
84
100
  */
85
- void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
101
+ void compute_codes_add_centroids(
102
+ const float* x,
103
+ uint8_t* codes,
104
+ size_t n,
105
+ const float* centroids = nullptr) const override;
86
106
 
87
107
  /** lower-level encode function
88
108
  *
@@ -413,4 +413,100 @@ struct RangeSearchResultHandler {
413
413
  }
414
414
  };
415
415
 
416
+ /*****************************************************************
417
+ * Single best result handler.
418
+ * Tracks only the single best result, thus avoiding storing
419
+ * some temporary data in memory.
420
+ *****************************************************************/
421
+
422
+ template <class C>
423
+ struct SingleBestResultHandler {
424
+ using T = typename C::T;
425
+ using TI = typename C::TI;
426
+
427
+ int nq;
428
+ // contains exactly nq elements
429
+ T* dis_tab;
430
+ // contains exactly nq elements
431
+ TI* ids_tab;
432
+
433
+ SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
434
+ : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
435
+
436
+ struct SingleResultHandler {
437
+ SingleBestResultHandler& hr;
438
+
439
+ T min_dis;
440
+ TI min_idx;
441
+ size_t current_idx = 0;
442
+
443
+ SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
444
+
445
+ /// begin results for query # i
446
+ void begin(const size_t current_idx) {
447
+ this->current_idx = current_idx;
448
+ min_dis = HUGE_VALF;
449
+ min_idx = 0;
450
+ }
451
+
452
+ /// add one result for query i
453
+ void add_result(T dis, TI idx) {
454
+ if (C::cmp(min_dis, dis)) {
455
+ min_dis = dis;
456
+ min_idx = idx;
457
+ }
458
+ }
459
+
460
+ /// series of results for query i is done
461
+ void end() {
462
+ hr.dis_tab[current_idx] = min_dis;
463
+ hr.ids_tab[current_idx] = min_idx;
464
+ }
465
+ };
466
+
467
+ size_t i0, i1;
468
+
469
+ /// begin
470
+ void begin_multiple(size_t i0, size_t i1) {
471
+ this->i0 = i0;
472
+ this->i1 = i1;
473
+
474
+ for (size_t i = i0; i < i1; i++) {
475
+ this->dis_tab[i] = HUGE_VALF;
476
+ }
477
+ }
478
+
479
+ /// add results for query i0..i1 and j0..j1
480
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
481
+ for (int64_t i = i0; i < i1; i++) {
482
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
483
+
484
+ auto& min_distance = this->dis_tab[i];
485
+ auto& min_index = this->ids_tab[i];
486
+
487
+ for (size_t j = j0; j < j1; j++) {
488
+ const T distance = dis_tab_i[j];
489
+
490
+ if (C::cmp(min_distance, distance)) {
491
+ min_distance = distance;
492
+ min_index = j;
493
+ }
494
+ }
495
+ }
496
+ }
497
+
498
+ void add_result(const size_t i, const T dis, const TI idx) {
499
+ auto& min_distance = this->dis_tab[i];
500
+ auto& min_index = this->ids_tab[i];
501
+
502
+ if (C::cmp(min_distance, dis)) {
503
+ min_distance = dis;
504
+ min_index = idx;
505
+ }
506
+ }
507
+
508
+ /// series of results for queries i0..i1 is done
509
+ void end_multiple() {}
510
+ };
511
+
416
512
  } // namespace faiss