faiss 0.2.4 → 0.2.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (177)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/README.md +23 -21
  4. data/ext/faiss/extconf.rb +11 -0
  5. data/ext/faiss/index.cpp +4 -4
  6. data/ext/faiss/index_binary.cpp +6 -6
  7. data/ext/faiss/product_quantizer.cpp +4 -4
  8. data/lib/faiss/version.rb +1 -1
  9. data/vendor/faiss/faiss/AutoTune.cpp +13 -0
  10. data/vendor/faiss/faiss/IVFlib.cpp +101 -2
  11. data/vendor/faiss/faiss/IVFlib.h +26 -2
  12. data/vendor/faiss/faiss/Index.cpp +36 -3
  13. data/vendor/faiss/faiss/Index.h +43 -6
  14. data/vendor/faiss/faiss/Index2Layer.cpp +6 -2
  15. data/vendor/faiss/faiss/Index2Layer.h +6 -1
  16. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +219 -16
  17. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +63 -5
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
  20. data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +18 -3
  22. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
  23. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
  24. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
  25. data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
  26. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
  27. data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
  28. data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
  29. data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
  30. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
  31. data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
  32. data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
  33. data/vendor/faiss/faiss/IndexFastScan.h +145 -0
  34. data/vendor/faiss/faiss/IndexFlat.cpp +34 -21
  35. data/vendor/faiss/faiss/IndexFlat.h +7 -4
  36. data/vendor/faiss/faiss/IndexFlatCodes.cpp +35 -1
  37. data/vendor/faiss/faiss/IndexFlatCodes.h +12 -0
  38. data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
  39. data/vendor/faiss/faiss/IndexHNSW.h +4 -2
  40. data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
  41. data/vendor/faiss/faiss/IndexIDMap.h +107 -0
  42. data/vendor/faiss/faiss/IndexIVF.cpp +121 -33
  43. data/vendor/faiss/faiss/IndexIVF.h +35 -16
  44. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +84 -7
  45. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +63 -1
  46. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
  47. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
  48. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
  49. data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
  50. data/vendor/faiss/faiss/IndexIVFFlat.cpp +37 -17
  51. data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
  52. data/vendor/faiss/faiss/IndexIVFPQ.cpp +234 -50
  53. data/vendor/faiss/faiss/IndexIVFPQ.h +5 -1
  54. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
  55. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
  56. data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
  57. data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
  58. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +3 -1
  59. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +2 -1
  60. data/vendor/faiss/faiss/IndexLSH.cpp +4 -2
  61. data/vendor/faiss/faiss/IndexLSH.h +2 -1
  62. data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
  63. data/vendor/faiss/faiss/IndexLattice.h +3 -1
  64. data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -3
  65. data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
  66. data/vendor/faiss/faiss/IndexNSG.cpp +37 -3
  67. data/vendor/faiss/faiss/IndexNSG.h +25 -1
  68. data/vendor/faiss/faiss/IndexPQ.cpp +106 -69
  69. data/vendor/faiss/faiss/IndexPQ.h +19 -5
  70. data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
  71. data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
  72. data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
  73. data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
  74. data/vendor/faiss/faiss/IndexRefine.cpp +8 -4
  75. data/vendor/faiss/faiss/IndexRefine.h +4 -2
  76. data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
  77. data/vendor/faiss/faiss/IndexReplicas.h +2 -1
  78. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
  79. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
  80. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +26 -15
  81. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -7
  82. data/vendor/faiss/faiss/IndexShards.cpp +4 -1
  83. data/vendor/faiss/faiss/IndexShards.h +2 -1
  84. data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
  85. data/vendor/faiss/faiss/MetaIndexes.h +3 -81
  86. data/vendor/faiss/faiss/VectorTransform.cpp +43 -0
  87. data/vendor/faiss/faiss/VectorTransform.h +22 -4
  88. data/vendor/faiss/faiss/clone_index.cpp +23 -1
  89. data/vendor/faiss/faiss/clone_index.h +3 -0
  90. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
  91. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
  92. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
  93. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
  94. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
  95. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
  96. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
  97. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
  98. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
  99. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
  100. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
  101. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
  102. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +0 -4
  103. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  104. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
  105. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
  106. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
  107. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
  108. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
  109. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
  110. data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
  111. data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
  112. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
  113. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
  114. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
  115. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
  116. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
  117. data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
  118. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
  119. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +116 -47
  121. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +44 -13
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
  123. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
  124. data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
  125. data/vendor/faiss/faiss/impl/HNSW.cpp +123 -27
  126. data/vendor/faiss/faiss/impl/HNSW.h +19 -16
  127. data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
  128. data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +6 -28
  130. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +6 -1
  131. data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
  132. data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
  133. data/vendor/faiss/faiss/impl/NSG.cpp +1 -1
  134. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
  138. data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +192 -36
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +40 -20
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +97 -173
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
  144. data/vendor/faiss/faiss/impl/index_read.cpp +240 -9
  145. data/vendor/faiss/faiss/impl/index_write.cpp +237 -5
  146. data/vendor/faiss/faiss/impl/kmeans1d.cpp +6 -4
  147. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
  148. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
  149. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
  150. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
  151. data/vendor/faiss/faiss/index_factory.cpp +196 -7
  152. data/vendor/faiss/faiss/index_io.h +5 -0
  153. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
  154. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
  155. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
  156. data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
  157. data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
  158. data/vendor/faiss/faiss/utils/Heap.h +31 -15
  159. data/vendor/faiss/faiss/utils/distances.cpp +380 -56
  160. data/vendor/faiss/faiss/utils/distances.h +113 -15
  161. data/vendor/faiss/faiss/utils/distances_simd.cpp +726 -6
  162. data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
  163. data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
  164. data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
  165. data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
  166. data/vendor/faiss/faiss/utils/fp16.h +11 -0
  167. data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
  168. data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
  169. data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
  170. data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
  171. data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
  172. data/vendor/faiss/faiss/utils/random.cpp +53 -0
  173. data/vendor/faiss/faiss/utils/random.h +5 -0
  174. data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
  175. data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
  176. data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
  177. metadata +37 -3
@@ -5,21 +5,15 @@
5
5
  * LICENSE file in the root directory of this source tree.
6
6
  */
7
7
 
8
- // -*- c++ -*-
9
-
10
8
  #include <faiss/impl/ResidualQuantizer.h>
11
9
 
12
10
  #include <algorithm>
11
+ #include <cmath>
13
12
  #include <cstddef>
14
13
  #include <cstdio>
15
14
  #include <cstring>
16
15
  #include <memory>
17
16
 
18
- #include <faiss/impl/FaissAssert.h>
19
- #include <faiss/impl/ResidualQuantizer.h>
20
- #include <faiss/utils/utils.h>
21
-
22
- #include <faiss/Clustering.h>
23
17
  #include <faiss/IndexFlat.h>
24
18
  #include <faiss/VectorTransform.h>
25
19
  #include <faiss/impl/AuxIndexStructures.h>
@@ -27,7 +21,6 @@
27
21
  #include <faiss/utils/Heap.h>
28
22
  #include <faiss/utils/distances.h>
29
23
  #include <faiss/utils/hamming.h>
30
- #include <faiss/utils/simdlib.h>
31
24
  #include <faiss/utils/utils.h>
32
25
 
33
26
  extern "C" {
@@ -47,15 +40,34 @@ int sgemm_(
47
40
  float* beta,
48
41
  float* c,
49
42
  FINTEGER* ldc);
43
+
44
+ // http://www.netlib.org/clapack/old/single/sgels.c
45
+ // solve least squares
46
+
47
+ int sgelsd_(
48
+ FINTEGER* m,
49
+ FINTEGER* n,
50
+ FINTEGER* nrhs,
51
+ float* a,
52
+ FINTEGER* lda,
53
+ float* b,
54
+ FINTEGER* ldb,
55
+ float* s,
56
+ float* rcond,
57
+ FINTEGER* rank,
58
+ float* work,
59
+ FINTEGER* lwork,
60
+ FINTEGER* iwork,
61
+ FINTEGER* info);
50
62
  }
51
63
 
52
64
  namespace faiss {
53
65
 
54
66
  ResidualQuantizer::ResidualQuantizer()
55
67
  : train_type(Train_progressive_dim),
68
+ niter_codebook_refine(5),
56
69
  max_beam_size(5),
57
70
  use_beam_LUT(0),
58
- max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
59
71
  assign_index_factory(nullptr) {
60
72
  d = 0;
61
73
  M = 0;
@@ -81,6 +93,39 @@ ResidualQuantizer::ResidualQuantizer(
81
93
  Search_type_t search_type)
82
94
  : ResidualQuantizer(d, std::vector<size_t>(M, nbits), search_type) {}
83
95
 
96
+ void ResidualQuantizer::initialize_from(
97
+ const ResidualQuantizer& other,
98
+ int skip_M) {
99
+ FAISS_THROW_IF_NOT(M + skip_M <= other.M);
100
+ FAISS_THROW_IF_NOT(skip_M >= 0);
101
+
102
+ Search_type_t this_search_type = search_type;
103
+ int this_M = M;
104
+
105
+ // a first good approximation: override everything
106
+ *this = other;
107
+
108
+ // adjust derived values
109
+ M = this_M;
110
+ search_type = this_search_type;
111
+ nbits.resize(M);
112
+ memcpy(nbits.data(),
113
+ other.nbits.data() + skip_M,
114
+ nbits.size() * sizeof(nbits[0]));
115
+
116
+ set_derived_values();
117
+
118
+ // resize codebooks if trained
119
+ if (codebooks.size() > 0) {
120
+ FAISS_THROW_IF_NOT(codebooks.size() == other.total_codebook_size * d);
121
+ codebooks.resize(total_codebook_size * d);
122
+ memcpy(codebooks.data(),
123
+ other.codebooks.data() + other.codebook_offsets[skip_M] * d,
124
+ codebooks.size() * sizeof(codebooks[0]));
125
+ // TODO: norm_tabs?
126
+ }
127
+ }
128
+
84
129
  void beam_search_encode_step(
85
130
  size_t d,
86
131
  size_t K,
@@ -245,8 +290,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
245
290
  }
246
291
  train_residuals = residuals1;
247
292
  }
248
- train_type_t tt = train_type_t(train_type & 1023);
249
-
250
293
  std::vector<float> codebooks;
251
294
  float obj = 0;
252
295
 
@@ -259,7 +302,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
259
302
 
260
303
  double t1 = getmillisecs();
261
304
 
262
- if (tt == Train_default) {
305
+ if (!(train_type & Train_progressive_dim)) { // regular kmeans
263
306
  Clustering clus(d, K, cp);
264
307
  clus.train(
265
308
  train_residuals.size() / d,
@@ -268,7 +311,7 @@ void ResidualQuantizer::train(size_t n, const float* x) {
268
311
  codebooks.swap(clus.centroids);
269
312
  assign_index->reset();
270
313
  obj = clus.iteration_stats.back().obj;
271
- } else if (tt == Train_progressive_dim) {
314
+ } else { // progressive dim clustering
272
315
  ProgressiveDimClustering clus(d, K, cp);
273
316
  ProgressiveDimIndexFactory default_fac;
274
317
  clus.train(
@@ -277,8 +320,6 @@ void ResidualQuantizer::train(size_t n, const float* x) {
277
320
  assign_index_factory ? *assign_index_factory : default_fac);
278
321
  codebooks.swap(clus.centroids);
279
322
  obj = clus.iteration_stats.back().obj;
280
- } else {
281
- FAISS_THROW_MSG("train type not supported");
282
323
  }
283
324
  clustering_time += (getmillisecs() - t1) / 1000;
284
325
 
@@ -350,6 +391,19 @@ void ResidualQuantizer::train(size_t n, const float* x) {
350
391
  cur_beam_size = new_beam_size;
351
392
  }
352
393
 
394
+ is_trained = true;
395
+
396
+ if (train_type & Train_refine_codebook) {
397
+ for (int iter = 0; iter < niter_codebook_refine; iter++) {
398
+ if (verbose) {
399
+ printf("re-estimating the codebooks to minimize "
400
+ "quantization errors (iter %d).\n",
401
+ iter);
402
+ }
403
+ retrain_AQ_codebook(n, x);
404
+ }
405
+ }
406
+
353
407
  // find min and max norms
354
408
  std::vector<float> norms(n);
355
409
 
@@ -359,33 +413,128 @@ void ResidualQuantizer::train(size_t n, const float* x) {
359
413
  }
360
414
 
361
415
  // fvec_norms_L2sqr(norms.data(), x, d, n);
416
+ train_norm(n, norms.data());
417
+
418
+ if (!(train_type & Skip_codebook_tables)) {
419
+ compute_codebook_tables();
420
+ }
421
+ }
422
+
423
+ float ResidualQuantizer::retrain_AQ_codebook(size_t n, const float* x) {
424
+ FAISS_THROW_IF_NOT_MSG(n >= total_codebook_size, "too few training points");
425
+
426
+ if (verbose) {
427
+ printf(" encoding %zd training vectors\n", n);
428
+ }
429
+ std::vector<uint8_t> codes(n * code_size);
430
+ compute_codes(x, codes.data(), n);
362
431
 
363
- norm_min = HUGE_VALF;
364
- norm_max = -HUGE_VALF;
365
- for (idx_t i = 0; i < n; i++) {
366
- if (norms[i] < norm_min) {
367
- norm_min = norms[i];
432
+ // compute reconstruction error
433
+ float input_recons_error;
434
+ {
435
+ std::vector<float> x_recons(n * d);
436
+ decode(codes.data(), x_recons.data(), n);
437
+ input_recons_error = fvec_L2sqr(x, x_recons.data(), n * d);
438
+ if (verbose) {
439
+ printf(" input quantization error %g\n", input_recons_error);
368
440
  }
369
- if (norms[i] > norm_max) {
370
- norm_max = norms[i];
441
+ }
442
+
443
+ // build matrix of the linear system
444
+ std::vector<float> C(n * total_codebook_size);
445
+ for (size_t i = 0; i < n; i++) {
446
+ BitstringReader bsr(codes.data() + i * code_size, code_size);
447
+ for (int m = 0; m < M; m++) {
448
+ int idx = bsr.read(nbits[m]);
449
+ C[i + (codebook_offsets[m] + idx) * n] = 1;
450
+ }
451
+ }
452
+
453
+ // transpose training vectors
454
+ std::vector<float> xt(n * d);
455
+
456
+ for (size_t i = 0; i < n; i++) {
457
+ for (size_t j = 0; j < d; j++) {
458
+ xt[j * n + i] = x[i * d + j];
371
459
  }
372
460
  }
373
461
 
374
- if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
375
- size_t k = (1 << 8);
376
- if (search_type == ST_norm_cqint4) {
377
- k = (1 << 4);
462
+ { // solve least squares
463
+ FINTEGER lwork = -1;
464
+ FINTEGER di = d, ni = n, tcsi = total_codebook_size;
465
+ FINTEGER info = -1, rank = -1;
466
+
467
+ float rcond = 1e-4; // this is an important parameter because the code
468
+ // matrix can be rank deficient for small problems,
469
+ // the default rcond=-1 does not work
470
+ float worksize;
471
+ std::vector<float> sing_vals(total_codebook_size);
472
+ FINTEGER nlvl = 1000; // formula is a bit convoluted so let's take an
473
+ // upper bound
474
+ std::vector<FINTEGER> iwork(
475
+ 3 * total_codebook_size * nlvl + 11 * total_codebook_size);
476
+
477
+ // worksize query
478
+ sgelsd_(&ni,
479
+ &tcsi,
480
+ &di,
481
+ C.data(),
482
+ &ni,
483
+ xt.data(),
484
+ &ni,
485
+ sing_vals.data(),
486
+ &rcond,
487
+ &rank,
488
+ &worksize,
489
+ &lwork,
490
+ iwork.data(),
491
+ &info);
492
+ FAISS_THROW_IF_NOT(info == 0);
493
+
494
+ lwork = worksize;
495
+ std::vector<float> work(lwork);
496
+ // actual call
497
+ sgelsd_(&ni,
498
+ &tcsi,
499
+ &di,
500
+ C.data(),
501
+ &ni,
502
+ xt.data(),
503
+ &ni,
504
+ sing_vals.data(),
505
+ &rcond,
506
+ &rank,
507
+ work.data(),
508
+ &lwork,
509
+ iwork.data(),
510
+ &info);
511
+ FAISS_THROW_IF_NOT_FMT(info == 0, "SGELS returned info=%d", int(info));
512
+ if (verbose) {
513
+ printf(" sgelsd rank=%d/%d\n",
514
+ int(rank),
515
+ int(total_codebook_size));
378
516
  }
379
- Clustering1D clus(k);
380
- clus.train_exact(n, norms.data());
381
- qnorm.add(clus.k, clus.centroids.data());
382
517
  }
383
518
 
384
- is_trained = true;
519
+ // result is in xt, re-transpose to codebook
385
520
 
386
- if (!(train_type & Skip_codebook_tables)) {
387
- compute_codebook_tables();
521
+ for (size_t i = 0; i < total_codebook_size; i++) {
522
+ for (size_t j = 0; j < d; j++) {
523
+ codebooks[i * d + j] = xt[j * n + i];
524
+ FAISS_THROW_IF_NOT(std::isfinite(codebooks[i * d + j]));
525
+ }
526
+ }
527
+
528
+ float output_recons_error = 0;
529
+ for (size_t j = 0; j < d; j++) {
530
+ output_recons_error += fvec_norm_L2sqr(
531
+ xt.data() + total_codebook_size + n * j,
532
+ n - total_codebook_size);
388
533
  }
534
+ if (verbose) {
535
+ printf(" output quantization error %g\n", output_recons_error);
536
+ }
537
+ return output_recons_error;
389
538
  }
390
539
 
391
540
  size_t ResidualQuantizer::memory_per_point(int beam_size) const {
@@ -400,10 +549,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
400
549
  return mem;
401
550
  }
402
551
 
403
- void ResidualQuantizer::compute_codes(
552
+ void ResidualQuantizer::compute_codes_add_centroids(
404
553
  const float* x,
405
554
  uint8_t* codes_out,
406
- size_t n) const {
555
+ size_t n,
556
+ const float* centroids) const {
407
557
  FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
408
558
 
409
559
  size_t mem = memory_per_point();
@@ -415,7 +565,12 @@ void ResidualQuantizer::compute_codes(
415
565
  }
416
566
  for (size_t i0 = 0; i0 < n; i0 += bs) {
417
567
  size_t i1 = std::min(n, i0 + bs);
418
- compute_codes(x + i0 * d, codes_out + i0 * code_size, i1 - i0);
568
+ const float* cent = nullptr;
569
+ if (centroids != nullptr) {
570
+ cent = centroids + i0 * d;
571
+ }
572
+ compute_codes_add_centroids(
573
+ x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
419
574
  }
420
575
  return;
421
576
  }
@@ -489,7 +644,8 @@ void ResidualQuantizer::compute_codes(
489
644
  codes.data(),
490
645
  codes_out,
491
646
  M * max_beam_size,
492
- norms.size() > 0 ? norms.data() : nullptr);
647
+ norms.size() > 0 ? norms.data() : nullptr,
648
+ centroids);
493
649
  }
494
650
 
495
651
  void ResidualQuantizer::refine_beam(
@@ -24,25 +24,31 @@ namespace faiss {
24
24
 
25
25
  struct ResidualQuantizer : AdditiveQuantizer {
26
26
  /// initialization
27
- enum train_type_t {
28
- Train_default = 0, ///< regular k-means
29
- Train_progressive_dim = 1, ///< progressive dim clustering
30
- Train_default_Train_top_beam = 1024,
31
- Train_progressive_dim_Train_top_beam = 1025,
32
- Train_default_Skip_codebook_tables = 2048,
33
- Train_progressive_dim_Skip_codebook_tables = 2049,
34
- Train_default_Train_top_beam_Skip_codebook_tables = 3072,
35
- Train_progressive_dim_Train_top_beam_Skip_codebook_tables = 3073,
36
- };
37
27
 
28
+ // Was enum but that does not work so well with bitmasks
29
+ using train_type_t = int;
30
+
31
+ /// Binary or of the Train_* flags below
38
32
  train_type_t train_type;
39
33
 
40
- // set this bit on train_type if beam is to be trained only on the
41
- // first element of the beam (faster but less accurate)
34
+ /// regular k-means (minimal amount of computation)
35
+ static const int Train_default = 0;
36
+
37
+ /// progressive dim clustering (set by default)
38
+ static const int Train_progressive_dim = 1;
39
+
40
+ /// do a few iterations of codebook refinement after first level estimation
41
+ static const int Train_refine_codebook = 2;
42
+
43
+ /// number of iterations for codebook refinement.
44
+ int niter_codebook_refine;
45
+
46
+ /** set this bit on train_type if beam is to be trained only on the
47
+ * first element of the beam (faster but less accurate) */
42
48
  static const int Train_top_beam = 1024;
43
49
 
44
- // set this bit to not autmatically compute the codebook tables
45
- // after training
50
+ /** set this bit to *not* autmatically compute the codebook tables
51
+ * after training */
46
52
  static const int Skip_codebook_tables = 2048;
47
53
 
48
54
  /// beam size used for training and for encoding
@@ -51,10 +57,6 @@ struct ResidualQuantizer : AdditiveQuantizer {
51
57
  /// use LUT for beam search
52
58
  int use_beam_LUT;
53
59
 
54
- /// distance matrixes with beam search can get large, so use this
55
- /// to batch computations at encoding time.
56
- size_t max_mem_distances;
57
-
58
60
  /// clustering parameters
59
61
  ProgressiveDimClusteringParameters cp;
60
62
 
@@ -74,15 +76,33 @@ struct ResidualQuantizer : AdditiveQuantizer {
74
76
 
75
77
  ResidualQuantizer();
76
78
 
77
- // Train the residual quantizer
79
+ /// Train the residual quantizer
78
80
  void train(size_t n, const float* x) override;
79
81
 
82
+ /// Copy the M codebook levels from other, starting from skip_M
83
+ void initialize_from(const ResidualQuantizer& other, int skip_M = 0);
84
+
85
+ /** Encode the vectors and compute codebook that minimizes the quantization
86
+ * error on these codes
87
+ *
88
+ * @param x training vectors, size n * d
89
+ * @param n nb of training vectors, n >= total_codebook_size
90
+ * @return returns quantization error for the new codebook with old
91
+ * codes
92
+ */
93
+ float retrain_AQ_codebook(size_t n, const float* x);
94
+
80
95
  /** Encode a set of vectors
81
96
  *
82
97
  * @param x vectors to encode, size n * d
83
98
  * @param codes output codes, size n * code_size
99
+ * @param centroids centroids to be added to x, size n * d
84
100
  */
85
- void compute_codes(const float* x, uint8_t* codes, size_t n) const override;
101
+ void compute_codes_add_centroids(
102
+ const float* x,
103
+ uint8_t* codes,
104
+ size_t n,
105
+ const float* centroids = nullptr) const override;
86
106
 
87
107
  /** lower-level encode function
88
108
  *
@@ -413,4 +413,100 @@ struct RangeSearchResultHandler {
413
413
  }
414
414
  };
415
415
 
416
+ /*****************************************************************
417
+ * Single best result handler.
418
+ * Tracks the only best result, thus avoiding storing
419
+ * some temporary data in memory.
420
+ *****************************************************************/
421
+
422
+ template <class C>
423
+ struct SingleBestResultHandler {
424
+ using T = typename C::T;
425
+ using TI = typename C::TI;
426
+
427
+ int nq;
428
+ // contains exactly nq elements
429
+ T* dis_tab;
430
+ // contains exactly nq elements
431
+ TI* ids_tab;
432
+
433
+ SingleBestResultHandler(size_t nq, T* dis_tab, TI* ids_tab)
434
+ : nq(nq), dis_tab(dis_tab), ids_tab(ids_tab) {}
435
+
436
+ struct SingleResultHandler {
437
+ SingleBestResultHandler& hr;
438
+
439
+ T min_dis;
440
+ TI min_idx;
441
+ size_t current_idx = 0;
442
+
443
+ SingleResultHandler(SingleBestResultHandler& hr) : hr(hr) {}
444
+
445
+ /// begin results for query # i
446
+ void begin(const size_t current_idx) {
447
+ this->current_idx = current_idx;
448
+ min_dis = HUGE_VALF;
449
+ min_idx = 0;
450
+ }
451
+
452
+ /// add one result for query i
453
+ void add_result(T dis, TI idx) {
454
+ if (C::cmp(min_dis, dis)) {
455
+ min_dis = dis;
456
+ min_idx = idx;
457
+ }
458
+ }
459
+
460
+ /// series of results for query i is done
461
+ void end() {
462
+ hr.dis_tab[current_idx] = min_dis;
463
+ hr.ids_tab[current_idx] = min_idx;
464
+ }
465
+ };
466
+
467
+ size_t i0, i1;
468
+
469
+ /// begin
470
+ void begin_multiple(size_t i0, size_t i1) {
471
+ this->i0 = i0;
472
+ this->i1 = i1;
473
+
474
+ for (size_t i = i0; i < i1; i++) {
475
+ this->dis_tab[i] = HUGE_VALF;
476
+ }
477
+ }
478
+
479
+ /// add results for query i0..i1 and j0..j1
480
+ void add_results(size_t j0, size_t j1, const T* dis_tab) {
481
+ for (int64_t i = i0; i < i1; i++) {
482
+ const T* dis_tab_i = dis_tab + (j1 - j0) * (i - i0) - j0;
483
+
484
+ auto& min_distance = this->dis_tab[i];
485
+ auto& min_index = this->ids_tab[i];
486
+
487
+ for (size_t j = j0; j < j1; j++) {
488
+ const T distance = dis_tab_i[j];
489
+
490
+ if (C::cmp(min_distance, distance)) {
491
+ min_distance = distance;
492
+ min_index = j;
493
+ }
494
+ }
495
+ }
496
+ }
497
+
498
+ void add_result(const size_t i, const T dis, const TI idx) {
499
+ auto& min_distance = this->dis_tab[i];
500
+ auto& min_index = this->ids_tab[i];
501
+
502
+ if (C::cmp(min_distance, dis)) {
503
+ min_distance = dis;
504
+ min_index = idx;
505
+ }
506
+ }
507
+
508
+ /// series of results for queries i0..i1 is done
509
+ void end_multiple() {}
510
+ };
511
+
416
512
  } // namespace faiss