faiss 0.2.6 → 0.2.7

Files changed (189)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -23,6 +23,10 @@
 #include <faiss/utils/hamming.h>
 #include <faiss/utils/utils.h>
 
+#include <faiss/utils/simdlib.h>
+
+#include <faiss/utils/approx_topk/approx_topk.h>
+
 extern "C" {
 
 // general matrix multiplication
@@ -63,12 +67,7 @@ int sgelsd_(
 
 namespace faiss {
 
-ResidualQuantizer::ResidualQuantizer()
-        : train_type(Train_progressive_dim),
-          niter_codebook_refine(5),
-          max_beam_size(5),
-          use_beam_LUT(0),
-          assign_index_factory(nullptr) {
+ResidualQuantizer::ResidualQuantizer() {
     d = 0;
     M = 0;
     verbose = false;
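The constructor no longer lists train_type, niter_codebook_refine, max_beam_size, use_beam_LUT and assign_index_factory in its initializer list; presumably those members now carry in-class default values in the header (the header diff is not shown here, so this is an assumption). A minimal sketch of that pattern, with a hypothetical struct name:

// Hypothetical illustration only; not the FAISS declaration.
struct ExampleQuantizer {
    int max_beam_size = 5;            // default supplied at the declaration
    int use_beam_LUT = 0;
    void* assign_index_factory = nullptr;

    ExampleQuantizer() {}             // constructor needs no initializer list
};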
@@ -139,12 +138,11 @@ void beam_search_encode_step(
         int32_t* new_codes,   /// size (n, new_beam_size, m + 1)
         float* new_residuals, /// size (n, new_beam_size, d)
         float* new_distances, /// size (n, new_beam_size)
-        Index* assign_index) {
+        Index* assign_index,
+        ApproxTopK_mode_t approx_topk_mode) {
     // we have to fill in the whole output matrix
     FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
 
-    using idx_t = Index::idx_t;
-
     std::vector<float> cent_distances;
     std::vector<idx_t> cent_ids;
 
@@ -230,15 +228,36 @@ void beam_search_encode_step(
             new_distances_i[i] = C::neutral();
         }
         std::vector<int> perm(new_beam_size, -1);
-        heap_addn<C>(
-                new_beam_size,
-                new_distances_i,
-                perm.data(),
-                cent_distances_i,
-                nullptr,
-                beam_size * K);
+
+#define HANDLE_APPROX(NB, BD)                                  \
+    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
+        HeapWithBuckets<C, NB, BD>::bs_addn(                   \
+                beam_size,                                     \
+                K,                                             \
+                cent_distances_i,                              \
+                new_beam_size,                                 \
+                new_distances_i,                               \
+                perm.data());                                  \
+        break;
+
+        switch (approx_topk_mode) {
+            HANDLE_APPROX(8, 3)
+            HANDLE_APPROX(8, 2)
+            HANDLE_APPROX(16, 2)
+            HANDLE_APPROX(32, 2)
+            default:
+                heap_addn<C>(
+                        new_beam_size,
+                        new_distances_i,
+                        perm.data(),
+                        cent_distances_i,
+                        nullptr,
+                        beam_size * K);
+        }
         heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
 
+#undef HANDLE_APPROX
+
         for (int j = 0; j < new_beam_size; j++) {
             int js = perm[j] / K;
             int ls = perm[j] % K;
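The switch above selects an approximate, bucketed top-k (HeapWithBuckets<C, NB, BD>::bs_addn) for the APPROX_TOPK_BUCKETS_* modes and falls back to the exact heap_addn otherwise. For intuition, here is a minimal standalone sketch of that kind of mode switch, contrasting an exact selection with a cheaper bucketed approximation; the names (Mode, topk_exact, topk_buckets) are illustrative only and not part of the FAISS API.

// Illustrative sketch, not FAISS code.
#include <algorithm>
#include <cstdio>
#include <limits>
#include <vector>

enum class Mode { EXACT, BUCKETS };

// Exact top-k: a full partial sort over all candidates.
std::vector<float> topk_exact(std::vector<float> dist, size_t k) {
    k = std::min(k, dist.size());
    std::partial_sort(dist.begin(), dist.begin() + k, dist.end());
    dist.resize(k);
    return dist;
}

// Approximate top-k: keep only a per-bucket minimum, then select among the
// bucket minima. Cheaper, but it can miss candidates when several of the
// true top-k land in the same bucket.
std::vector<float> topk_buckets(const std::vector<float>& dist, size_t k, size_t nbuckets) {
    std::vector<float> mins(nbuckets, std::numeric_limits<float>::max());
    for (size_t i = 0; i < dist.size(); i++) {
        mins[i % nbuckets] = std::min(mins[i % nbuckets], dist[i]);
    }
    return topk_exact(mins, k);
}

int main() {
    std::vector<float> dist = {5, 1, 9, 3, 7, 2, 8, 4, 6, 0};
    Mode mode = Mode::BUCKETS; // cf. approx_topk_mode above
    std::vector<float> best =
            (mode == Mode::BUCKETS) ? topk_buckets(dist, 3, 4) : topk_exact(dist, 3);
    for (float d : best) {
        std::printf("%g\n", d);
    }
    return 0;
}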
@@ -364,7 +383,8 @@ void ResidualQuantizer::train(size_t n, const float* x) {
                    new_codes.data() + i0 * new_beam_size * (m + 1),
                    new_residuals.data() + i0 * new_beam_size * d,
                    new_distances.data() + i0 * new_beam_size,
-                   assign_index.get());
+                   assign_index.get(),
+                   approx_topk_mode);
         }
         codes.swap(new_codes);
         residuals.swap(new_residuals);
@@ -544,11 +564,185 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
544
564
  size_t mem;
545
565
  mem = beam_size * d * 2 * sizeof(float); // size for 2 beams at a time
546
566
  mem += beam_size * beam_size *
547
- (sizeof(float) +
548
- sizeof(Index::idx_t)); // size for 1 beam search result
567
+ (sizeof(float) + sizeof(idx_t)); // size for 1 beam search result
549
568
  return mem;
550
569
  }
551
570
 
571
+ // a namespace full of preallocated buffers
572
+ namespace {
573
+
574
+ // Preallocated memory chunk for refine_beam_mp() call
575
+ struct RefineBeamMemoryPool {
576
+ std::vector<int32_t> new_codes;
577
+ std::vector<float> new_residuals;
578
+
579
+ std::vector<float> residuals;
580
+ std::vector<int32_t> codes;
581
+ std::vector<float> distances;
582
+ };
583
+
584
+ // Preallocated memory chunk for refine_beam_LUT_mp() call
585
+ struct RefineBeamLUTMemoryPool {
586
+ std::vector<int32_t> new_codes;
587
+ std::vector<float> new_distances;
588
+
589
+ std::vector<int32_t> codes;
590
+ std::vector<float> distances;
591
+ };
592
+
593
+ // this is for use_beam_LUT == 0 in compute_codes_add_centroids_mp_lut0() call
594
+ struct ComputeCodesAddCentroidsLUT0MemoryPool {
595
+ std::vector<int32_t> codes;
596
+ std::vector<float> norms;
597
+ std::vector<float> distances;
598
+ std::vector<float> residuals;
599
+ RefineBeamMemoryPool refine_beam_pool;
600
+ };
601
+
602
+ // this is for use_beam_LUT == 1 in compute_codes_add_centroids_mp_lut1() call
603
+ struct ComputeCodesAddCentroidsLUT1MemoryPool {
604
+ std::vector<int32_t> codes;
605
+ std::vector<float> distances;
606
+ std::vector<float> query_norms;
607
+ std::vector<float> query_cp;
608
+ std::vector<float> residuals;
609
+ RefineBeamLUTMemoryPool refine_beam_lut_pool;
610
+ };
611
+
612
+ } // namespace
613
+
614
+ // forward declaration
615
+ void refine_beam_mp(
616
+ const ResidualQuantizer& rq,
617
+ size_t n,
618
+ size_t beam_size,
619
+ const float* x,
620
+ int out_beam_size,
621
+ int32_t* out_codes,
622
+ float* out_residuals,
623
+ float* out_distances,
624
+ RefineBeamMemoryPool& pool);
625
+
626
+ // forward declaration
627
+ void refine_beam_LUT_mp(
628
+ const ResidualQuantizer& rq,
629
+ size_t n,
630
+ const float* query_norms, // size n
631
+ const float* query_cp, //
632
+ int out_beam_size,
633
+ int32_t* out_codes,
634
+ float* out_distances,
635
+ RefineBeamLUTMemoryPool& pool);
636
+
637
+ // this is for use_beam_LUT == 0
638
+ void compute_codes_add_centroids_mp_lut0(
639
+ const ResidualQuantizer& rq,
640
+ const float* x,
641
+ uint8_t* codes_out,
642
+ size_t n,
643
+ const float* centroids,
644
+ ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
645
+ pool.codes.resize(rq.max_beam_size * rq.M * n);
646
+ pool.distances.resize(rq.max_beam_size * n);
647
+
648
+ pool.residuals.resize(rq.max_beam_size * n * rq.d);
649
+
650
+ refine_beam_mp(
651
+ rq,
652
+ n,
653
+ 1,
654
+ x,
655
+ rq.max_beam_size,
656
+ pool.codes.data(),
657
+ pool.residuals.data(),
658
+ pool.distances.data(),
659
+ pool.refine_beam_pool);
660
+
661
+ if (rq.search_type == ResidualQuantizer::ST_norm_float ||
662
+ rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
663
+ rq.search_type == ResidualQuantizer::ST_norm_qint4) {
664
+ pool.norms.resize(n);
665
+ // recover the norms of reconstruction as
666
+ // || original_vector - residual ||^2
667
+ for (size_t i = 0; i < n; i++) {
668
+ pool.norms[i] = fvec_L2sqr(
669
+ x + i * rq.d,
670
+ pool.residuals.data() + i * rq.max_beam_size * rq.d,
671
+ rq.d);
672
+ }
673
+ }
674
+
675
+ // pack only the first code of the beam
676
+ // (hence the ld_codes=M * max_beam_size)
677
+ rq.pack_codes(
678
+ n,
679
+ pool.codes.data(),
680
+ codes_out,
681
+ rq.M * rq.max_beam_size,
682
+ (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
683
+ centroids);
684
+ }
685
+
686
+ // use_beam_LUT == 1
687
+ void compute_codes_add_centroids_mp_lut1(
688
+ const ResidualQuantizer& rq,
689
+ const float* x,
690
+ uint8_t* codes_out,
691
+ size_t n,
692
+ const float* centroids,
693
+ ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
694
+ //
695
+ pool.codes.resize(rq.max_beam_size * rq.M * n);
696
+ pool.distances.resize(rq.max_beam_size * n);
697
+
698
+ FAISS_THROW_IF_NOT_MSG(
699
+ rq.codebook_cross_products.size() ==
700
+ rq.total_codebook_size * rq.total_codebook_size,
701
+ "call compute_codebook_tables first");
702
+
703
+ pool.query_norms.resize(n);
704
+ fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
705
+
706
+ pool.query_cp.resize(n * rq.total_codebook_size);
707
+ {
708
+ FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
709
+ float zero = 0, one = 1;
710
+ sgemm_("Transposed",
711
+ "Not transposed",
712
+ &ti,
713
+ &ni,
714
+ &di,
715
+ &one,
716
+ rq.codebooks.data(),
717
+ &di,
718
+ x,
719
+ &di,
720
+ &zero,
721
+ pool.query_cp.data(),
722
+ &ti);
723
+ }
724
+
725
+ refine_beam_LUT_mp(
726
+ rq,
727
+ n,
728
+ pool.query_norms.data(),
729
+ pool.query_cp.data(),
730
+ rq.max_beam_size,
731
+ pool.codes.data(),
732
+ pool.distances.data(),
733
+ pool.refine_beam_lut_pool);
734
+
735
+ // pack only the first code of the beam
736
+ // (hence the ld_codes=M * max_beam_size)
737
+ rq.pack_codes(
738
+ n,
739
+ pool.codes.data(),
740
+ codes_out,
741
+ rq.M * rq.max_beam_size,
742
+ nullptr,
743
+ centroids);
744
+ }
745
+
552
746
  void ResidualQuantizer::compute_codes_add_centroids(
553
747
  const float* x,
554
748
  uint8_t* codes_out,
@@ -556,184 +750,212 @@ void ResidualQuantizer::compute_codes_add_centroids(
556
750
  const float* centroids) const {
557
751
  FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
558
752
 
753
+ //
559
754
  size_t mem = memory_per_point();
560
- if (n > 1 && mem * n > max_mem_distances) {
561
- // then split queries to reduce temp memory
562
- size_t bs = max_mem_distances / mem;
563
- if (bs == 0) {
564
- bs = 1; // otherwise we can't do much
565
- }
566
- for (size_t i0 = 0; i0 < n; i0 += bs) {
567
- size_t i1 = std::min(n, i0 + bs);
568
- const float* cent = nullptr;
569
- if (centroids != nullptr) {
570
- cent = centroids + i0 * d;
571
- }
572
- compute_codes_add_centroids(
573
- x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
574
- }
575
- return;
755
+
756
+ size_t bs = max_mem_distances / mem;
757
+ if (bs == 0) {
758
+ bs = 1; // otherwise we can't do much
576
759
  }
577
760
 
578
- std::vector<int32_t> codes(max_beam_size * M * n);
579
- std::vector<float> norms;
580
- std::vector<float> distances(max_beam_size * n);
761
+ // prepare memory pools
762
+ ComputeCodesAddCentroidsLUT0MemoryPool pool0;
763
+ ComputeCodesAddCentroidsLUT1MemoryPool pool1;
581
764
 
582
- if (use_beam_LUT == 0) {
583
- std::vector<float> residuals(max_beam_size * n * d);
765
+ for (size_t i0 = 0; i0 < n; i0 += bs) {
766
+ size_t i1 = std::min(n, i0 + bs);
767
+ const float* cent = nullptr;
768
+ if (centroids != nullptr) {
769
+ cent = centroids + i0 * d;
770
+ }
584
771
 
585
- refine_beam(
586
- n,
587
- 1,
588
- x,
589
- max_beam_size,
590
- codes.data(),
591
- residuals.data(),
592
- distances.data());
593
-
594
- if (search_type == ST_norm_float || search_type == ST_norm_qint8 ||
595
- search_type == ST_norm_qint4) {
596
- norms.resize(n);
597
- // recover the norms of reconstruction as
598
- // || original_vector - residual ||^2
599
- for (size_t i = 0; i < n; i++) {
600
- norms[i] = fvec_L2sqr(
601
- x + i * d, residuals.data() + i * max_beam_size * d, d);
602
- }
772
+ // compute_codes_add_centroids(
773
+ // x + i0 * d,
774
+ // codes_out + i0 * code_size,
775
+ // i1 - i0,
776
+ // cent);
777
+ if (use_beam_LUT == 0) {
778
+ compute_codes_add_centroids_mp_lut0(
779
+ *this,
780
+ x + i0 * d,
781
+ codes_out + i0 * code_size,
782
+ i1 - i0,
783
+ cent,
784
+ pool0);
785
+ } else if (use_beam_LUT == 1) {
786
+ compute_codes_add_centroids_mp_lut1(
787
+ *this,
788
+ x + i0 * d,
789
+ codes_out + i0 * code_size,
790
+ i1 - i0,
791
+ cent,
792
+ pool1);
603
793
  }
604
- } else if (use_beam_LUT == 1) {
605
- FAISS_THROW_IF_NOT_MSG(
606
- codebook_cross_products.size() ==
607
- total_codebook_size * total_codebook_size,
608
- "call compute_codebook_tables first");
609
-
610
- std::vector<float> query_norms(n);
611
- fvec_norms_L2sqr(query_norms.data(), x, d, n);
612
-
613
- std::vector<float> query_cp(n * total_codebook_size);
614
- {
615
- FINTEGER ti = total_codebook_size, di = d, ni = n;
616
- float zero = 0, one = 1;
617
- sgemm_("Transposed",
618
- "Not transposed",
619
- &ti,
620
- &ni,
621
- &di,
622
- &one,
623
- codebooks.data(),
624
- &di,
625
- x,
626
- &di,
627
- &zero,
628
- query_cp.data(),
629
- &ti);
630
- }
631
-
632
- refine_beam_LUT(
633
- n,
634
- query_norms.data(),
635
- query_cp.data(),
636
- max_beam_size,
637
- codes.data(),
638
- distances.data());
639
- }
640
- // pack only the first code of the beam (hence the ld_codes=M *
641
- // max_beam_size)
642
- pack_codes(
643
- n,
644
- codes.data(),
645
- codes_out,
646
- M * max_beam_size,
647
- norms.size() > 0 ? norms.data() : nullptr,
648
- centroids);
794
+ }
649
795
  }
650
796
 
651
- void ResidualQuantizer::refine_beam(
797
+ void refine_beam_mp(
798
+ const ResidualQuantizer& rq,
652
799
  size_t n,
653
800
  size_t beam_size,
654
801
  const float* x,
655
802
  int out_beam_size,
656
803
  int32_t* out_codes,
657
804
  float* out_residuals,
658
- float* out_distances) const {
805
+ float* out_distances,
806
+ RefineBeamMemoryPool& pool) {
659
807
  int cur_beam_size = beam_size;
660
808
 
661
- std::vector<float> residuals(x, x + n * d * beam_size);
662
- std::vector<int32_t> codes;
663
- std::vector<float> distances;
664
809
  double t0 = getmillisecs();
665
810
 
811
+ // find the max_beam_size
812
+ int max_beam_size = 0;
813
+ {
814
+ int tmp_beam_size = cur_beam_size;
815
+ for (int m = 0; m < rq.M; m++) {
816
+ int K = 1 << rq.nbits[m];
817
+ int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
818
+ tmp_beam_size = new_beam_size;
819
+
820
+ if (max_beam_size < new_beam_size) {
821
+ max_beam_size = new_beam_size;
822
+ }
823
+ }
824
+ }
825
+
826
+ // preallocate buffers
827
+ pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
828
+ pool.new_residuals.resize(n * max_beam_size * rq.d);
829
+
830
+ pool.codes.resize(n * max_beam_size * (rq.M + 1));
831
+ pool.distances.resize(n * max_beam_size);
832
+ pool.residuals.resize(n * rq.d * max_beam_size);
833
+
834
+ for (size_t i = 0; i < n * rq.d * beam_size; i++) {
835
+ pool.residuals[i] = x[i];
836
+ }
837
+
838
+ // set up pointers to buffers
839
+ int32_t* __restrict codes_ptr = pool.codes.data();
840
+ float* __restrict residuals_ptr = pool.residuals.data();
841
+
842
+ int32_t* __restrict new_codes_ptr = pool.new_codes.data();
843
+ float* __restrict new_residuals_ptr = pool.new_residuals.data();
844
+
845
+ // index
666
846
  std::unique_ptr<Index> assign_index;
667
- if (assign_index_factory) {
668
- assign_index.reset((*assign_index_factory)(d));
847
+ if (rq.assign_index_factory) {
848
+ assign_index.reset((*rq.assign_index_factory)(rq.d));
669
849
  } else {
670
- assign_index.reset(new IndexFlatL2(d));
850
+ assign_index.reset(new IndexFlatL2(rq.d));
671
851
  }
672
852
 
673
- for (int m = 0; m < M; m++) {
674
- int K = 1 << nbits[m];
853
+ // main loop
854
+ size_t codes_size = 0;
855
+ size_t distances_size = 0;
856
+ size_t residuals_size = 0;
675
857
 
676
- const float* codebooks_m =
677
- this->codebooks.data() + codebook_offsets[m] * d;
858
+ for (int m = 0; m < rq.M; m++) {
859
+ int K = 1 << rq.nbits[m];
678
860
 
679
- int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
861
+ const float* __restrict codebooks_m =
862
+ rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
680
863
 
681
- std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
682
- std::vector<float> new_residuals(n * new_beam_size * d);
683
- distances.resize(n * new_beam_size);
864
+ const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
865
+
866
+ codes_size = n * new_beam_size * (m + 1);
867
+ residuals_size = n * new_beam_size * rq.d;
868
+ distances_size = n * new_beam_size;
684
869
 
685
870
  beam_search_encode_step(
686
- d,
871
+ rq.d,
687
872
  K,
688
873
  codebooks_m,
689
874
  n,
690
875
  cur_beam_size,
691
- residuals.data(),
876
+ // residuals.data(),
877
+ residuals_ptr,
692
878
  m,
693
- codes.data(),
879
+ // codes.data(),
880
+ codes_ptr,
694
881
  new_beam_size,
695
- new_codes.data(),
696
- new_residuals.data(),
697
- distances.data(),
698
- assign_index.get());
882
+ // new_codes.data(),
883
+ new_codes_ptr,
884
+ // new_residuals.data(),
885
+ new_residuals_ptr,
886
+ pool.distances.data(),
887
+ assign_index.get(),
888
+ rq.approx_topk_mode);
699
889
 
700
890
  assign_index->reset();
701
891
 
702
- codes.swap(new_codes);
703
- residuals.swap(new_residuals);
892
+ std::swap(codes_ptr, new_codes_ptr);
893
+ std::swap(residuals_ptr, new_residuals_ptr);
704
894
 
705
895
  cur_beam_size = new_beam_size;
706
896
 
707
- if (verbose) {
897
+ if (rq.verbose) {
708
898
  float sum_distances = 0;
709
- for (int j = 0; j < distances.size(); j++) {
710
- sum_distances += distances[j];
899
+ // for (int j = 0; j < distances.size(); j++) {
900
+ // sum_distances += distances[j];
901
+ // }
902
+ for (int j = 0; j < distances_size; j++) {
903
+ sum_distances += pool.distances[j];
711
904
  }
905
+
712
906
  printf("[%.3f s] encode stage %d, %d bits, "
713
907
  "total error %g, beam_size %d\n",
714
908
  (getmillisecs() - t0) / 1000,
715
909
  m,
716
- int(nbits[m]),
910
+ int(rq.nbits[m]),
717
911
  sum_distances,
718
912
  cur_beam_size);
719
913
  }
720
914
  }
721
915
 
722
916
  if (out_codes) {
723
- memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
917
+ // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
918
+ memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
724
919
  }
725
920
  if (out_residuals) {
921
+ // memcpy(out_residuals,
922
+ // residuals.data(),
923
+ // residuals.size() * sizeof(residuals[0]));
726
924
  memcpy(out_residuals,
727
- residuals.data(),
728
- residuals.size() * sizeof(residuals[0]));
925
+ residuals_ptr,
926
+ residuals_size * sizeof(*residuals_ptr));
729
927
  }
730
928
  if (out_distances) {
929
+ // memcpy(out_distances,
930
+ // distances.data(),
931
+ // distances.size() * sizeof(distances[0]));
731
932
  memcpy(out_distances,
732
- distances.data(),
733
- distances.size() * sizeof(distances[0]));
933
+ pool.distances.data(),
934
+ distances_size * sizeof(pool.distances[0]));
734
935
  }
735
936
  }
736
937
 
938
+ void ResidualQuantizer::refine_beam(
939
+ size_t n,
940
+ size_t beam_size,
941
+ const float* x,
942
+ int out_beam_size,
943
+ int32_t* out_codes,
944
+ float* out_residuals,
945
+ float* out_distances) const {
946
+ RefineBeamMemoryPool pool;
947
+ refine_beam_mp(
948
+ *this,
949
+ n,
950
+ beam_size,
951
+ x,
952
+ out_beam_size,
953
+ out_codes,
954
+ out_residuals,
955
+ out_distances,
956
+ pool);
957
+ }
958
+
737
959
  /*******************************************************************
738
960
  * Functions using the dot products between codebook entries
739
961
  *******************************************************************/
@@ -765,6 +987,186 @@ void ResidualQuantizer::compute_codebook_tables() {
765
987
  }
766
988
  }
767
989
 
990
+ namespace {
991
+
992
+ template <size_t M, size_t NK>
993
+ void accum_and_store_tab(
994
+ const size_t m_offset,
995
+ const float* const __restrict codebook_cross_norms,
996
+ const uint64_t* const __restrict codebook_offsets,
997
+ const int32_t* const __restrict codes_i,
998
+ const size_t b,
999
+ const size_t ldc,
1000
+ const size_t K,
1001
+ float* const __restrict output) {
1002
+ // load pointers into registers
1003
+ const float* cbs[M];
1004
+ for (size_t ij = 0; ij < M; ij++) {
1005
+ const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
1006
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1007
+ }
1008
+
1009
+ // do accumulation in registers using SIMD.
1010
+ // It is possible that compiler may be smart enough so that
1011
+ // this manual SIMD unrolling might be unneeded.
1012
+ #if defined(__AVX2__) || defined(__aarch64__)
1013
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
1014
+
1015
+ // process in chunks of size (8 * NK) floats
1016
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1017
+ simd8float32 regs[NK];
1018
+ for (size_t ik = 0; ik < NK; ik++) {
1019
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
1020
+ }
1021
+
1022
+ for (size_t ij = 1; ij < M; ij++) {
1023
+ for (size_t ik = 0; ik < NK; ik++) {
1024
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1025
+ }
1026
+ }
1027
+
1028
+ // write the result
1029
+ for (size_t ik = 0; ik < NK; ik++) {
1030
+ regs[ik].storeu(output + kk + ik * 8);
1031
+ }
1032
+ }
1033
+ #else
1034
+ const size_t K8 = 0;
1035
+ #endif
1036
+
1037
+ // process leftovers
1038
+ for (size_t kk = K8; kk < K; kk++) {
1039
+ float reg = cbs[0][kk];
1040
+ for (size_t ij = 1; ij < M; ij++) {
1041
+ reg += cbs[ij][kk];
1042
+ }
1043
+ output[b * K + kk] = reg;
1044
+ }
1045
+ }
1046
+
1047
+ template <size_t M, size_t NK>
1048
+ void accum_and_add_tab(
1049
+ const size_t m_offset,
1050
+ const float* const __restrict codebook_cross_norms,
1051
+ const uint64_t* const __restrict codebook_offsets,
1052
+ const int32_t* const __restrict codes_i,
1053
+ const size_t b,
1054
+ const size_t ldc,
1055
+ const size_t K,
1056
+ float* const __restrict output) {
1057
+ // load pointers into registers
1058
+ const float* cbs[M];
1059
+ for (size_t ij = 0; ij < M; ij++) {
1060
+ const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
1061
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1062
+ }
1063
+
1064
+ // do accumulation in registers using SIMD.
1065
+ // It is possible that compiler may be smart enough so that
1066
+ // this manual SIMD unrolling might be unneeded.
1067
+ #if defined(__AVX2__) || defined(__aarch64__)
1068
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
1069
+
1070
+ // process in chunks of size (8 * NK) floats
1071
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1072
+ simd8float32 regs[NK];
1073
+ for (size_t ik = 0; ik < NK; ik++) {
1074
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
1075
+ }
1076
+
1077
+ for (size_t ij = 1; ij < M; ij++) {
1078
+ for (size_t ik = 0; ik < NK; ik++) {
1079
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1080
+ }
1081
+ }
1082
+
1083
+ // write the result
1084
+ for (size_t ik = 0; ik < NK; ik++) {
1085
+ simd8float32 existing(output + kk + ik * 8);
1086
+ existing += regs[ik];
1087
+ existing.storeu(output + kk + ik * 8);
1088
+ }
1089
+ }
1090
+ #else
1091
+ const size_t K8 = 0;
1092
+ #endif
1093
+
1094
+ // process leftovers
1095
+ for (size_t kk = K8; kk < K; kk++) {
1096
+ float reg = cbs[0][kk];
1097
+ for (size_t ij = 1; ij < M; ij++) {
1098
+ reg += cbs[ij][kk];
1099
+ }
1100
+ output[b * K + kk] += reg;
1101
+ }
1102
+ }
1103
+
1104
+ template <size_t M, size_t NK>
1105
+ void accum_and_finalize_tab(
1106
+ const float* const __restrict codebook_cross_norms,
1107
+ const uint64_t* const __restrict codebook_offsets,
1108
+ const int32_t* const __restrict codes_i,
1109
+ const size_t b,
1110
+ const size_t ldc,
1111
+ const size_t K,
1112
+ const float* const __restrict distances_i,
1113
+ const float* const __restrict cd_common,
1114
+ float* const __restrict output) {
1115
+ // load pointers into registers
1116
+ const float* cbs[M];
1117
+ for (size_t ij = 0; ij < M; ij++) {
1118
+ const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
1119
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1120
+ }
1121
+
1122
+ // do accumulation in registers using SIMD.
1123
+ // It is possible that compiler may be smart enough so that
1124
+ // this manual SIMD unrolling might be unneeded.
1125
+ #if defined(__AVX2__) || defined(__aarch64__)
1126
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
1127
+
1128
+ // process in chunks of size (8 * NK) floats
1129
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1130
+ simd8float32 regs[NK];
1131
+ for (size_t ik = 0; ik < NK; ik++) {
1132
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
1133
+ }
1134
+
1135
+ for (size_t ij = 1; ij < M; ij++) {
1136
+ for (size_t ik = 0; ik < NK; ik++) {
1137
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1138
+ }
1139
+ }
1140
+
1141
+ simd8float32 two(2.0f);
1142
+ for (size_t ik = 0; ik < NK; ik++) {
1143
+ // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
1144
+ // + 2 * dp[k];
1145
+
1146
+ simd8float32 common_v(cd_common + kk + ik * 8);
1147
+ common_v = fmadd(two, regs[ik], common_v);
1148
+
1149
+ common_v += simd8float32(distances_i[b]);
1150
+ common_v.storeu(output + b * K + kk + ik * 8);
1151
+ }
1152
+ }
1153
+ #else
1154
+ const size_t K8 = 0;
1155
+ #endif
1156
+
1157
+ // process leftovers
1158
+ for (size_t kk = K8; kk < K; kk++) {
1159
+ float reg = cbs[0][kk];
1160
+ for (size_t ij = 1; ij < M; ij++) {
1161
+ reg += cbs[ij][kk];
1162
+ }
1163
+
1164
+ output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
1165
+ }
1166
+ }
1167
+
1168
+ } // namespace
1169
+
768
1170
  void beam_search_encode_step_tab(
769
1171
  size_t K,
770
1172
  size_t n,
@@ -779,12 +1181,13 @@ void beam_search_encode_step_tab(
         const int32_t* codes,   // n * beam_size * m
         const float* distances, // n * beam_size
         size_t new_beam_size,
-        int32_t* new_codes,    // n * new_beam_size * (m + 1)
-        float* new_distances)  // n * new_beam_size
+        int32_t* new_codes,   // n * new_beam_size * (m + 1)
+        float* new_distances, // n * new_beam_size
+        ApproxTopK_mode_t approx_topk_mode) //
 {
     FAISS_THROW_IF_NOT(ldc >= K);
 
-#pragma omp parallel for if (n > 100)
+#pragma omp parallel for if (n > 100) schedule(dynamic)
     for (int64_t i = 0; i < n; i++) {
         std::vector<float> cent_distances(beam_size * K);
         std::vector<float> cd_common(K);
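The parallel loop now uses schedule(dynamic), so threads pick up per-query iterations as they finish rather than in fixed contiguous blocks, which helps when iterations have uneven cost. A small self-contained illustration of the scheduling clause (generic OpenMP, not FAISS code; compile with -fopenmp):

#include <cmath>
#include <cstdio>

// deliberately uneven per-iteration work
double work(int i) {
    double s = 0;
    for (int j = 0; j < (i % 7 + 1) * 100000; j++) {
        s += std::sqrt(double(j + i));
    }
    return s;
}

int main() {
    double total = 0;
    // dynamic scheduling: idle threads grab the next iteration, so slow
    // iterations do not leave other threads waiting on a static partition
    #pragma omp parallel for schedule(dynamic) reduction(+ : total)
    for (int i = 0; i < 64; i++) {
        total += work(i);
    }
    std::printf("total = %f\n", total);
    return 0;
}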
@@ -797,6 +1200,14 @@
             cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
         }
 
+        /*
+        // This is the baseline implementation. Its primary flaw
+        // that it writes way too many info to the temporary buffer
+        // called dp.
+        //
+        // This baseline code is kept intentionally because it is easy to
+        // understand what an optimized version optimizes exactly.
+        //
         for (size_t b = 0; b < beam_size; b++) {
             std::vector<float> dp(K);
 
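The comment above motivates the optimized path that follows: rather than streaming every codebook row through the dp buffer, the new code accumulates several rows in registers and touches memory once per block. A minimal sketch of that general idea (plain C++, not the FAISS kernels; the function names are illustrative):

// Illustrative sketch, not FAISS code.
#include <cstddef>

// Baseline: one read-modify-write pass over dp for every row.
void accum_rows_baseline(const float* rows, size_t nrows, size_t K, float* dp) {
    for (size_t r = 0; r < nrows; r++) {
        for (size_t k = 0; k < K; k++) {
            dp[k] += rows[r * K + k];
        }
    }
}

// Register-friendly variant: sum all rows for one output element in a local
// accumulator and write dp[k] once, cutting stores to the temporary buffer.
void accum_rows_blocked(const float* rows, size_t nrows, size_t K, float* dp) {
    for (size_t k = 0; k < K; k++) {
        float acc = 0.0f; // stays in a register
        for (size_t r = 0; r < nrows; r++) {
            acc += rows[r * K + k];
        }
        dp[k] += acc; // single write per output element
    }
}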
@@ -812,6 +1223,117 @@ void beam_search_encode_step_tab(
812
1223
  distances_i[b] + cd_common[k] + 2 * dp[k];
813
1224
  }
814
1225
  }
1226
+ */
1227
+
1228
+ // An optimized implementation that avoids using a temporary buffer
1229
+ // and does the accumulation in registers.
1230
+
1231
+ // Compute a sum of NK AQ codes.
1232
+ #define ACCUM_AND_FINALIZE_TAB(NK) \
1233
+ case NK: \
1234
+ for (size_t b = 0; b < beam_size; b++) { \
1235
+ accum_and_finalize_tab<NK, 4>( \
1236
+ codebook_cross_norms, \
1237
+ codebook_offsets, \
1238
+ codes_i, \
1239
+ b, \
1240
+ ldc, \
1241
+ K, \
1242
+ distances_i, \
1243
+ cd_common.data(), \
1244
+ cent_distances.data()); \
1245
+ } \
1246
+ break;
1247
+
1248
+ // this version contains many switch-case scenarios, but
1249
+ // they won't affect branch predictor.
1250
+ switch (m) {
1251
+ case 0:
1252
+ // trivial case
1253
+ for (size_t b = 0; b < beam_size; b++) {
1254
+ for (size_t k = 0; k < K; k++) {
1255
+ cent_distances[b * K + k] =
1256
+ distances_i[b] + cd_common[k];
1257
+ }
1258
+ }
1259
+ break;
1260
+
1261
+ ACCUM_AND_FINALIZE_TAB(1)
1262
+ ACCUM_AND_FINALIZE_TAB(2)
1263
+ ACCUM_AND_FINALIZE_TAB(3)
1264
+ ACCUM_AND_FINALIZE_TAB(4)
1265
+ ACCUM_AND_FINALIZE_TAB(5)
1266
+ ACCUM_AND_FINALIZE_TAB(6)
1267
+ ACCUM_AND_FINALIZE_TAB(7)
1268
+
1269
+ default: {
1270
+ // m >= 8 case.
1271
+
1272
+ // A temporary buffer has to be used due to the lack of
1273
+ // registers. But we'll try to accumulate up to 8 AQ codes in
1274
+ // registers and issue a single write operation to the buffer,
1275
+ // while the baseline does no accumulation. So, the number of
1276
+ // write operations to the temporary buffer is reduced 8x.
1277
+
1278
+ // allocate a temporary buffer
1279
+ std::vector<float> dp(K);
1280
+
1281
+ for (size_t b = 0; b < beam_size; b++) {
1282
+ // Initialize it. Compute a sum of first 8 AQ codes
1283
+ // because m >= 8 .
1284
+ accum_and_store_tab<8, 4>(
1285
+ m,
1286
+ codebook_cross_norms,
1287
+ codebook_offsets,
1288
+ codes_i,
1289
+ b,
1290
+ ldc,
1291
+ K,
1292
+ dp.data());
1293
+
1294
+ #define ACCUM_AND_ADD_TAB(NK) \
1295
+ case NK: \
1296
+ accum_and_add_tab<NK, 4>( \
1297
+ m, \
1298
+ codebook_cross_norms, \
1299
+ codebook_offsets + im, \
1300
+ codes_i + im, \
1301
+ b, \
1302
+ ldc, \
1303
+ K, \
1304
+ dp.data()); \
1305
+ break;
1306
+
1307
+ // accumulate up to 8 additional AQ codes into
1308
+ // a temporary buffer
1309
+ for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
1310
+ size_t m_left = m - im;
1311
+ if (m_left > 8) {
1312
+ m_left = 8;
1313
+ }
1314
+
1315
+ switch (m_left) {
1316
+ ACCUM_AND_ADD_TAB(1)
1317
+ ACCUM_AND_ADD_TAB(2)
1318
+ ACCUM_AND_ADD_TAB(3)
1319
+ ACCUM_AND_ADD_TAB(4)
1320
+ ACCUM_AND_ADD_TAB(5)
1321
+ ACCUM_AND_ADD_TAB(6)
1322
+ ACCUM_AND_ADD_TAB(7)
1323
+ ACCUM_AND_ADD_TAB(8)
1324
+ }
1325
+ }
1326
+
1327
+ // done. finalize the result
1328
+ for (size_t k = 0; k < K; k++) {
1329
+ cent_distances[b * K + k] =
1330
+ distances_i[b] + cd_common[k] + 2 * dp[k];
1331
+ }
1332
+ }
1333
+ }
1334
+ }
1335
+
1336
+ // the optimized implementation ends here
815
1337
 
816
1338
  using C = CMax<float, int>;
817
1339
  int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
@@ -824,15 +1346,38 @@ void beam_search_encode_step_tab(
             new_distances_i[i] = C::neutral();
         }
         std::vector<int> perm(new_beam_size, -1);
-        heap_addn<C>(
-                new_beam_size,
-                new_distances_i,
-                perm.data(),
-                cent_distances_i,
-                nullptr,
-                beam_size * K);
+
+#define HANDLE_APPROX(NB, BD)                                  \
+    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
+        HeapWithBuckets<C, NB, BD>::bs_addn(                   \
+                beam_size,                                     \
+                K,                                             \
+                cent_distances_i,                              \
+                new_beam_size,                                 \
+                new_distances_i,                               \
+                perm.data());                                  \
+        break;
+
+        switch (approx_topk_mode) {
+            HANDLE_APPROX(8, 3)
+            HANDLE_APPROX(8, 2)
+            HANDLE_APPROX(16, 2)
+            HANDLE_APPROX(32, 2)
+            default:
+                heap_addn<C>(
+                        new_beam_size,
+                        new_distances_i,
+                        perm.data(),
+                        cent_distances_i,
+                        nullptr,
+                        beam_size * K);
+                break;
+        }
+
         heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
 
+#undef HANDLE_APPROX
+
         for (int j = 0; j < new_beam_size; j++) {
             int js = perm[j] / K;
             int ls = perm[j] % K;
@@ -845,70 +1390,147 @@ void beam_search_encode_step_tab(
845
1390
  }
846
1391
  }
847
1392
 
848
- void ResidualQuantizer::refine_beam_LUT(
1393
+ //
1394
+ void refine_beam_LUT_mp(
1395
+ const ResidualQuantizer& rq,
849
1396
  size_t n,
850
1397
  const float* query_norms, // size n
851
1398
  const float* query_cp, //
852
1399
  int out_beam_size,
853
1400
  int32_t* out_codes,
854
- float* out_distances) const {
1401
+ float* out_distances,
1402
+ RefineBeamLUTMemoryPool& pool) {
855
1403
  int beam_size = 1;
856
1404
 
857
- std::vector<int32_t> codes;
858
- std::vector<float> distances(query_norms, query_norms + n);
859
1405
  double t0 = getmillisecs();
860
1406
 
861
- for (int m = 0; m < M; m++) {
862
- int K = 1 << nbits[m];
1407
+ // find the max_beam_size
1408
+ int max_beam_size = 0;
1409
+ {
1410
+ int tmp_beam_size = beam_size;
1411
+ for (int m = 0; m < rq.M; m++) {
1412
+ int K = 1 << rq.nbits[m];
1413
+ int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
1414
+ tmp_beam_size = new_beam_size;
1415
+
1416
+ if (max_beam_size < new_beam_size) {
1417
+ max_beam_size = new_beam_size;
1418
+ }
1419
+ }
1420
+ }
863
1421
 
1422
+ // preallocate buffers
1423
+ pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
1424
+ pool.new_distances.resize(n * max_beam_size);
1425
+
1426
+ pool.codes.resize(n * max_beam_size * (rq.M + 1));
1427
+ pool.distances.resize(n * max_beam_size);
1428
+
1429
+ for (size_t i = 0; i < n; i++) {
1430
+ pool.distances[i] = query_norms[i];
1431
+ }
1432
+
1433
+ // set up pointers to buffers
1434
+ int32_t* __restrict new_codes_ptr = pool.new_codes.data();
1435
+ float* __restrict new_distances_ptr = pool.new_distances.data();
1436
+
1437
+ int32_t* __restrict codes_ptr = pool.codes.data();
1438
+ float* __restrict distances_ptr = pool.distances.data();
1439
+
1440
+ // main loop
1441
+ size_t codes_size = 0;
1442
+ size_t distances_size = 0;
1443
+ for (int m = 0; m < rq.M; m++) {
1444
+ int K = 1 << rq.nbits[m];
1445
+
1446
+ // it is guaranteed that (new_beam_size <= than max_beam_size) ==
1447
+ // true
864
1448
  int new_beam_size = std::min(beam_size * K, out_beam_size);
865
- std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
866
- std::vector<float> new_distances(n * new_beam_size);
1449
+
1450
+ // std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
1451
+ // std::vector<float> new_distances(n * new_beam_size);
1452
+
1453
+ codes_size = n * new_beam_size * (m + 1);
1454
+ distances_size = n * new_beam_size;
867
1455
 
868
1456
  beam_search_encode_step_tab(
869
1457
  K,
870
1458
  n,
871
1459
  beam_size,
872
- codebook_cross_products.data() + codebook_offsets[m],
873
- total_codebook_size,
874
- codebook_offsets.data(),
875
- query_cp + codebook_offsets[m],
876
- total_codebook_size,
877
- cent_norms.data() + codebook_offsets[m],
1460
+ rq.codebook_cross_products.data() + rq.codebook_offsets[m],
1461
+ rq.total_codebook_size,
1462
+ rq.codebook_offsets.data(),
1463
+ query_cp + rq.codebook_offsets[m],
1464
+ rq.total_codebook_size,
1465
+ rq.cent_norms.data() + rq.codebook_offsets[m],
878
1466
  m,
879
- codes.data(),
880
- distances.data(),
1467
+ // codes.data(),
1468
+ codes_ptr,
1469
+ // distances.data(),
1470
+ distances_ptr,
881
1471
  new_beam_size,
882
- new_codes.data(),
883
- new_distances.data());
1472
+ // new_codes.data(),
1473
+ new_codes_ptr,
1474
+ // new_distances.data()
1475
+ new_distances_ptr,
1476
+ rq.approx_topk_mode);
1477
+
1478
+ // codes.swap(new_codes);
1479
+ std::swap(codes_ptr, new_codes_ptr);
1480
+ // distances.swap(new_distances);
1481
+ std::swap(distances_ptr, new_distances_ptr);
884
1482
 
885
- codes.swap(new_codes);
886
- distances.swap(new_distances);
887
1483
  beam_size = new_beam_size;
888
1484
 
889
- if (verbose) {
1485
+ if (rq.verbose) {
890
1486
  float sum_distances = 0;
891
- for (int j = 0; j < distances.size(); j++) {
892
- sum_distances += distances[j];
1487
+ // for (int j = 0; j < distances.size(); j++) {
1488
+ // sum_distances += distances[j];
1489
+ // }
1490
+ for (int j = 0; j < distances_size; j++) {
1491
+ sum_distances += distances_ptr[j];
893
1492
  }
894
1493
  printf("[%.3f s] encode stage %d, %d bits, "
895
1494
  "total error %g, beam_size %d\n",
896
1495
  (getmillisecs() - t0) / 1000,
897
1496
  m,
898
- int(nbits[m]),
1497
+ int(rq.nbits[m]),
899
1498
  sum_distances,
900
1499
  beam_size);
901
1500
  }
902
1501
  }
903
1502
 
904
1503
  if (out_codes) {
905
- memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
1504
+ // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
1505
+ memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
906
1506
  }
907
1507
  if (out_distances) {
1508
+ // memcpy(out_distances,
1509
+ // distances.data(),
1510
+ // distances.size() * sizeof(distances[0]));
908
1511
  memcpy(out_distances,
909
- distances.data(),
910
- distances.size() * sizeof(distances[0]));
1512
+ distances_ptr,
1513
+ distances_size * sizeof(*distances_ptr));
911
1514
  }
912
1515
  }
913
1516
 
1517
+ void ResidualQuantizer::refine_beam_LUT(
1518
+ size_t n,
1519
+ const float* query_norms, // size n
1520
+ const float* query_cp, //
1521
+ int out_beam_size,
1522
+ int32_t* out_codes,
1523
+ float* out_distances) const {
1524
+ RefineBeamLUTMemoryPool pool;
1525
+ refine_beam_LUT_mp(
1526
+ *this,
1527
+ n,
1528
+ query_norms,
1529
+ query_cp,
1530
+ out_beam_size,
1531
+ out_codes,
1532
+ out_distances,
1533
+ pool);
1534
+ }
1535
+
914
1536
  } // namespace faiss