faiss 0.2.6 → 0.2.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (189) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +4 -0
  3. data/ext/faiss/extconf.rb +1 -1
  4. data/lib/faiss/version.rb +1 -1
  5. data/lib/faiss.rb +2 -2
  6. data/vendor/faiss/faiss/AutoTune.cpp +15 -4
  7. data/vendor/faiss/faiss/AutoTune.h +0 -1
  8. data/vendor/faiss/faiss/Clustering.cpp +1 -5
  9. data/vendor/faiss/faiss/Clustering.h +0 -2
  10. data/vendor/faiss/faiss/IVFlib.h +0 -2
  11. data/vendor/faiss/faiss/Index.h +1 -2
  12. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +17 -3
  13. data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +10 -1
  14. data/vendor/faiss/faiss/IndexBinary.h +0 -1
  15. data/vendor/faiss/faiss/IndexBinaryFlat.cpp +2 -1
  16. data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -0
  17. data/vendor/faiss/faiss/IndexBinaryHash.cpp +1 -3
  18. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +273 -48
  19. data/vendor/faiss/faiss/IndexBinaryIVF.h +18 -11
  20. data/vendor/faiss/faiss/IndexFastScan.cpp +13 -10
  21. data/vendor/faiss/faiss/IndexFastScan.h +5 -1
  22. data/vendor/faiss/faiss/IndexFlat.cpp +16 -3
  23. data/vendor/faiss/faiss/IndexFlat.h +1 -1
  24. data/vendor/faiss/faiss/IndexFlatCodes.cpp +5 -0
  25. data/vendor/faiss/faiss/IndexFlatCodes.h +7 -2
  26. data/vendor/faiss/faiss/IndexHNSW.cpp +3 -6
  27. data/vendor/faiss/faiss/IndexHNSW.h +0 -1
  28. data/vendor/faiss/faiss/IndexIDMap.cpp +4 -4
  29. data/vendor/faiss/faiss/IndexIDMap.h +0 -2
  30. data/vendor/faiss/faiss/IndexIVF.cpp +155 -129
  31. data/vendor/faiss/faiss/IndexIVF.h +121 -61
  32. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +2 -2
  33. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +12 -11
  34. data/vendor/faiss/faiss/IndexIVFFastScan.h +6 -1
  35. data/vendor/faiss/faiss/IndexIVFPQ.cpp +221 -165
  36. data/vendor/faiss/faiss/IndexIVFPQ.h +1 -0
  37. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +6 -1
  38. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +0 -2
  39. data/vendor/faiss/faiss/IndexNNDescent.cpp +1 -2
  40. data/vendor/faiss/faiss/IndexNNDescent.h +0 -1
  41. data/vendor/faiss/faiss/IndexNSG.cpp +1 -2
  42. data/vendor/faiss/faiss/IndexPQ.cpp +7 -9
  43. data/vendor/faiss/faiss/IndexRefine.cpp +1 -1
  44. data/vendor/faiss/faiss/IndexReplicas.cpp +3 -4
  45. data/vendor/faiss/faiss/IndexReplicas.h +0 -1
  46. data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +8 -1
  47. data/vendor/faiss/faiss/IndexRowwiseMinMax.h +7 -0
  48. data/vendor/faiss/faiss/IndexShards.cpp +26 -109
  49. data/vendor/faiss/faiss/IndexShards.h +2 -3
  50. data/vendor/faiss/faiss/IndexShardsIVF.cpp +246 -0
  51. data/vendor/faiss/faiss/IndexShardsIVF.h +42 -0
  52. data/vendor/faiss/faiss/MetaIndexes.cpp +86 -0
  53. data/vendor/faiss/faiss/MetaIndexes.h +29 -0
  54. data/vendor/faiss/faiss/MetricType.h +14 -0
  55. data/vendor/faiss/faiss/VectorTransform.cpp +8 -10
  56. data/vendor/faiss/faiss/VectorTransform.h +1 -3
  57. data/vendor/faiss/faiss/clone_index.cpp +232 -18
  58. data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +25 -3
  59. data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +7 -0
  60. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +78 -0
  61. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +20 -6
  62. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +7 -1
  63. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +21 -7
  64. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +7 -0
  65. data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +7 -0
  66. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +10 -3
  67. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +7 -1
  68. data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +11 -3
  69. data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +25 -2
  70. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +76 -29
  71. data/vendor/faiss/faiss/gpu/GpuCloner.h +2 -2
  72. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +14 -13
  73. data/vendor/faiss/faiss/gpu/GpuDistance.h +18 -6
  74. data/vendor/faiss/faiss/gpu/GpuIndex.h +23 -21
  75. data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +10 -10
  76. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +11 -12
  77. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +29 -50
  78. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +3 -3
  79. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +8 -8
  80. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +4 -4
  81. data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +2 -5
  82. data/vendor/faiss/faiss/gpu/impl/RemapIndices.cpp +9 -7
  83. data/vendor/faiss/faiss/gpu/impl/RemapIndices.h +4 -4
  84. data/vendor/faiss/faiss/gpu/perf/IndexWrapper-inl.h +2 -2
  85. data/vendor/faiss/faiss/gpu/perf/IndexWrapper.h +1 -1
  86. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +55 -6
  87. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +20 -6
  88. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +95 -25
  89. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +67 -16
  90. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFScalarQuantizer.cpp +4 -4
  91. data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +7 -7
  92. data/vendor/faiss/faiss/gpu/test/TestUtils.h +4 -4
  93. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  95. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +0 -7
  96. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +9 -9
  97. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  98. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +2 -7
  99. data/vendor/faiss/faiss/impl/CodePacker.cpp +67 -0
  100. data/vendor/faiss/faiss/impl/CodePacker.h +71 -0
  101. data/vendor/faiss/faiss/impl/DistanceComputer.h +0 -2
  102. data/vendor/faiss/faiss/impl/HNSW.cpp +3 -7
  103. data/vendor/faiss/faiss/impl/HNSW.h +6 -9
  104. data/vendor/faiss/faiss/impl/IDSelector.cpp +1 -1
  105. data/vendor/faiss/faiss/impl/IDSelector.h +39 -1
  106. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +62 -51
  107. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +11 -12
  108. data/vendor/faiss/faiss/impl/NNDescent.cpp +3 -9
  109. data/vendor/faiss/faiss/impl/NNDescent.h +10 -10
  110. data/vendor/faiss/faiss/impl/NSG.cpp +1 -6
  111. data/vendor/faiss/faiss/impl/NSG.h +4 -7
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +1 -15
  113. data/vendor/faiss/faiss/impl/PolysemousTraining.h +11 -10
  114. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +0 -7
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -12
  116. data/vendor/faiss/faiss/impl/ProductQuantizer.h +2 -4
  117. data/vendor/faiss/faiss/impl/Quantizer.h +6 -3
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +796 -174
  119. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +16 -8
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +3 -5
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +4 -4
  122. data/vendor/faiss/faiss/impl/ThreadedIndex-inl.h +3 -3
  123. data/vendor/faiss/faiss/impl/ThreadedIndex.h +4 -4
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +291 -0
  125. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +74 -0
  126. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +123 -0
  127. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +102 -0
  128. data/vendor/faiss/faiss/impl/index_read.cpp +13 -10
  129. data/vendor/faiss/faiss/impl/index_write.cpp +3 -4
  130. data/vendor/faiss/faiss/impl/kmeans1d.cpp +0 -1
  131. data/vendor/faiss/faiss/impl/kmeans1d.h +3 -3
  132. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  133. data/vendor/faiss/faiss/impl/platform_macros.h +61 -0
  134. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +48 -4
  135. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +18 -4
  136. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +2 -2
  137. data/vendor/faiss/faiss/index_factory.cpp +8 -10
  138. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +29 -12
  139. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +8 -2
  140. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  141. data/vendor/faiss/faiss/invlists/DirectMap.h +2 -4
  142. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +118 -18
  143. data/vendor/faiss/faiss/invlists/InvertedLists.h +44 -4
  144. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  145. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +1 -1
  146. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  147. data/vendor/faiss/faiss/python/python_callbacks.h +1 -1
  148. data/vendor/faiss/faiss/utils/AlignedTable.h +3 -1
  149. data/vendor/faiss/faiss/utils/Heap.cpp +139 -3
  150. data/vendor/faiss/faiss/utils/Heap.h +35 -1
  151. data/vendor/faiss/faiss/utils/approx_topk/approx_topk.h +84 -0
  152. data/vendor/faiss/faiss/utils/approx_topk/avx2-inl.h +196 -0
  153. data/vendor/faiss/faiss/utils/approx_topk/generic.h +138 -0
  154. data/vendor/faiss/faiss/utils/approx_topk/mode.h +34 -0
  155. data/vendor/faiss/faiss/utils/approx_topk_hamming/approx_topk_hamming.h +367 -0
  156. data/vendor/faiss/faiss/utils/distances.cpp +61 -7
  157. data/vendor/faiss/faiss/utils/distances.h +11 -0
  158. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +346 -0
  159. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +36 -0
  160. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +42 -0
  161. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +40 -0
  162. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +352 -0
  163. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +32 -0
  164. data/vendor/faiss/faiss/utils/distances_simd.cpp +515 -327
  165. data/vendor/faiss/faiss/utils/extra_distances-inl.h +17 -1
  166. data/vendor/faiss/faiss/utils/extra_distances.cpp +37 -8
  167. data/vendor/faiss/faiss/utils/extra_distances.h +2 -1
  168. data/vendor/faiss/faiss/utils/fp16-fp16c.h +7 -0
  169. data/vendor/faiss/faiss/utils/fp16-inl.h +7 -0
  170. data/vendor/faiss/faiss/utils/fp16.h +7 -0
  171. data/vendor/faiss/faiss/utils/hamming-inl.h +0 -456
  172. data/vendor/faiss/faiss/utils/hamming.cpp +104 -120
  173. data/vendor/faiss/faiss/utils/hamming.h +21 -10
  174. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +535 -0
  175. data/vendor/faiss/faiss/utils/hamming_distance/common.h +48 -0
  176. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +519 -0
  177. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +26 -0
  178. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +614 -0
  179. data/vendor/faiss/faiss/utils/partitioning.cpp +21 -25
  180. data/vendor/faiss/faiss/utils/simdlib_avx2.h +344 -3
  181. data/vendor/faiss/faiss/utils/simdlib_emulated.h +390 -0
  182. data/vendor/faiss/faiss/utils/simdlib_neon.h +655 -130
  183. data/vendor/faiss/faiss/utils/sorting.cpp +692 -0
  184. data/vendor/faiss/faiss/utils/sorting.h +71 -0
  185. data/vendor/faiss/faiss/utils/transpose/transpose-avx2-inl.h +165 -0
  186. data/vendor/faiss/faiss/utils/utils.cpp +4 -176
  187. data/vendor/faiss/faiss/utils/utils.h +2 -9
  188. metadata +29 -3
  189. data/vendor/faiss/faiss/gpu/GpuClonerOptions.cpp +0 -26
@@ -23,6 +23,10 @@
23
23
  #include <faiss/utils/hamming.h>
24
24
  #include <faiss/utils/utils.h>
25
25
 
26
+ #include <faiss/utils/simdlib.h>
27
+
28
+ #include <faiss/utils/approx_topk/approx_topk.h>
29
+
26
30
  extern "C" {
27
31
 
28
32
  // general matrix multiplication
@@ -63,12 +67,7 @@ int sgelsd_(
63
67
 
64
68
  namespace faiss {
65
69
 
66
- ResidualQuantizer::ResidualQuantizer()
67
- : train_type(Train_progressive_dim),
68
- niter_codebook_refine(5),
69
- max_beam_size(5),
70
- use_beam_LUT(0),
71
- assign_index_factory(nullptr) {
70
+ ResidualQuantizer::ResidualQuantizer() {
72
71
  d = 0;
73
72
  M = 0;
74
73
  verbose = false;
@@ -139,12 +138,11 @@ void beam_search_encode_step(
139
138
  int32_t* new_codes, /// size (n, new_beam_size, m + 1)
140
139
  float* new_residuals, /// size (n, new_beam_size, d)
141
140
  float* new_distances, /// size (n, new_beam_size)
142
- Index* assign_index) {
141
+ Index* assign_index,
142
+ ApproxTopK_mode_t approx_topk_mode) {
143
143
  // we have to fill in the whole output matrix
144
144
  FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
145
145
 
146
- using idx_t = Index::idx_t;
147
-
148
146
  std::vector<float> cent_distances;
149
147
  std::vector<idx_t> cent_ids;
150
148
 
@@ -230,15 +228,36 @@ void beam_search_encode_step(
230
228
  new_distances_i[i] = C::neutral();
231
229
  }
232
230
  std::vector<int> perm(new_beam_size, -1);
233
- heap_addn<C>(
234
- new_beam_size,
235
- new_distances_i,
236
- perm.data(),
237
- cent_distances_i,
238
- nullptr,
239
- beam_size * K);
231
+
232
+ #define HANDLE_APPROX(NB, BD) \
233
+ case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
234
+ HeapWithBuckets<C, NB, BD>::bs_addn( \
235
+ beam_size, \
236
+ K, \
237
+ cent_distances_i, \
238
+ new_beam_size, \
239
+ new_distances_i, \
240
+ perm.data()); \
241
+ break;
242
+
243
+ switch (approx_topk_mode) {
244
+ HANDLE_APPROX(8, 3)
245
+ HANDLE_APPROX(8, 2)
246
+ HANDLE_APPROX(16, 2)
247
+ HANDLE_APPROX(32, 2)
248
+ default:
249
+ heap_addn<C>(
250
+ new_beam_size,
251
+ new_distances_i,
252
+ perm.data(),
253
+ cent_distances_i,
254
+ nullptr,
255
+ beam_size * K);
256
+ }
240
257
  heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
241
258
 
259
+ #undef HANDLE_APPROX
260
+
242
261
  for (int j = 0; j < new_beam_size; j++) {
243
262
  int js = perm[j] / K;
244
263
  int ls = perm[j] % K;
@@ -364,7 +383,8 @@ void ResidualQuantizer::train(size_t n, const float* x) {
364
383
  new_codes.data() + i0 * new_beam_size * (m + 1),
365
384
  new_residuals.data() + i0 * new_beam_size * d,
366
385
  new_distances.data() + i0 * new_beam_size,
367
- assign_index.get());
386
+ assign_index.get(),
387
+ approx_topk_mode);
368
388
  }
369
389
  codes.swap(new_codes);
370
390
  residuals.swap(new_residuals);
@@ -544,11 +564,185 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
544
564
  size_t mem;
545
565
  mem = beam_size * d * 2 * sizeof(float); // size for 2 beams at a time
546
566
  mem += beam_size * beam_size *
547
- (sizeof(float) +
548
- sizeof(Index::idx_t)); // size for 1 beam search result
567
+ (sizeof(float) + sizeof(idx_t)); // size for 1 beam search result
549
568
  return mem;
550
569
  }
551
570
 
571
+ // a namespace full of preallocated buffers
572
+ namespace {
573
+
574
+ // Preallocated memory chunk for refine_beam_mp() call
575
+ struct RefineBeamMemoryPool {
576
+ std::vector<int32_t> new_codes;
577
+ std::vector<float> new_residuals;
578
+
579
+ std::vector<float> residuals;
580
+ std::vector<int32_t> codes;
581
+ std::vector<float> distances;
582
+ };
583
+
584
+ // Preallocated memory chunk for refine_beam_LUT_mp() call
585
+ struct RefineBeamLUTMemoryPool {
586
+ std::vector<int32_t> new_codes;
587
+ std::vector<float> new_distances;
588
+
589
+ std::vector<int32_t> codes;
590
+ std::vector<float> distances;
591
+ };
592
+
593
+ // this is for use_beam_LUT == 0 in compute_codes_add_centroids_mp_lut0() call
594
+ struct ComputeCodesAddCentroidsLUT0MemoryPool {
595
+ std::vector<int32_t> codes;
596
+ std::vector<float> norms;
597
+ std::vector<float> distances;
598
+ std::vector<float> residuals;
599
+ RefineBeamMemoryPool refine_beam_pool;
600
+ };
601
+
602
+ // this is for use_beam_LUT == 1 in compute_codes_add_centroids_mp_lut1() call
603
+ struct ComputeCodesAddCentroidsLUT1MemoryPool {
604
+ std::vector<int32_t> codes;
605
+ std::vector<float> distances;
606
+ std::vector<float> query_norms;
607
+ std::vector<float> query_cp;
608
+ std::vector<float> residuals;
609
+ RefineBeamLUTMemoryPool refine_beam_lut_pool;
610
+ };
611
+
612
+ } // namespace
613
+
614
+ // forward declaration
615
+ void refine_beam_mp(
616
+ const ResidualQuantizer& rq,
617
+ size_t n,
618
+ size_t beam_size,
619
+ const float* x,
620
+ int out_beam_size,
621
+ int32_t* out_codes,
622
+ float* out_residuals,
623
+ float* out_distances,
624
+ RefineBeamMemoryPool& pool);
625
+
626
+ // forward declaration
627
+ void refine_beam_LUT_mp(
628
+ const ResidualQuantizer& rq,
629
+ size_t n,
630
+ const float* query_norms, // size n
631
+ const float* query_cp, //
632
+ int out_beam_size,
633
+ int32_t* out_codes,
634
+ float* out_distances,
635
+ RefineBeamLUTMemoryPool& pool);
636
+
637
+ // this is for use_beam_LUT == 0
638
+ void compute_codes_add_centroids_mp_lut0(
639
+ const ResidualQuantizer& rq,
640
+ const float* x,
641
+ uint8_t* codes_out,
642
+ size_t n,
643
+ const float* centroids,
644
+ ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
645
+ pool.codes.resize(rq.max_beam_size * rq.M * n);
646
+ pool.distances.resize(rq.max_beam_size * n);
647
+
648
+ pool.residuals.resize(rq.max_beam_size * n * rq.d);
649
+
650
+ refine_beam_mp(
651
+ rq,
652
+ n,
653
+ 1,
654
+ x,
655
+ rq.max_beam_size,
656
+ pool.codes.data(),
657
+ pool.residuals.data(),
658
+ pool.distances.data(),
659
+ pool.refine_beam_pool);
660
+
661
+ if (rq.search_type == ResidualQuantizer::ST_norm_float ||
662
+ rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
663
+ rq.search_type == ResidualQuantizer::ST_norm_qint4) {
664
+ pool.norms.resize(n);
665
+ // recover the norms of reconstruction as
666
+ // || original_vector - residual ||^2
667
+ for (size_t i = 0; i < n; i++) {
668
+ pool.norms[i] = fvec_L2sqr(
669
+ x + i * rq.d,
670
+ pool.residuals.data() + i * rq.max_beam_size * rq.d,
671
+ rq.d);
672
+ }
673
+ }
674
+
675
+ // pack only the first code of the beam
676
+ // (hence the ld_codes=M * max_beam_size)
677
+ rq.pack_codes(
678
+ n,
679
+ pool.codes.data(),
680
+ codes_out,
681
+ rq.M * rq.max_beam_size,
682
+ (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
683
+ centroids);
684
+ }
685
+
686
+ // use_beam_LUT == 1
687
+ void compute_codes_add_centroids_mp_lut1(
688
+ const ResidualQuantizer& rq,
689
+ const float* x,
690
+ uint8_t* codes_out,
691
+ size_t n,
692
+ const float* centroids,
693
+ ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
694
+ //
695
+ pool.codes.resize(rq.max_beam_size * rq.M * n);
696
+ pool.distances.resize(rq.max_beam_size * n);
697
+
698
+ FAISS_THROW_IF_NOT_MSG(
699
+ rq.codebook_cross_products.size() ==
700
+ rq.total_codebook_size * rq.total_codebook_size,
701
+ "call compute_codebook_tables first");
702
+
703
+ pool.query_norms.resize(n);
704
+ fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
705
+
706
+ pool.query_cp.resize(n * rq.total_codebook_size);
707
+ {
708
+ FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
709
+ float zero = 0, one = 1;
710
+ sgemm_("Transposed",
711
+ "Not transposed",
712
+ &ti,
713
+ &ni,
714
+ &di,
715
+ &one,
716
+ rq.codebooks.data(),
717
+ &di,
718
+ x,
719
+ &di,
720
+ &zero,
721
+ pool.query_cp.data(),
722
+ &ti);
723
+ }
724
+
725
+ refine_beam_LUT_mp(
726
+ rq,
727
+ n,
728
+ pool.query_norms.data(),
729
+ pool.query_cp.data(),
730
+ rq.max_beam_size,
731
+ pool.codes.data(),
732
+ pool.distances.data(),
733
+ pool.refine_beam_lut_pool);
734
+
735
+ // pack only the first code of the beam
736
+ // (hence the ld_codes=M * max_beam_size)
737
+ rq.pack_codes(
738
+ n,
739
+ pool.codes.data(),
740
+ codes_out,
741
+ rq.M * rq.max_beam_size,
742
+ nullptr,
743
+ centroids);
744
+ }
745
+
552
746
  void ResidualQuantizer::compute_codes_add_centroids(
553
747
  const float* x,
554
748
  uint8_t* codes_out,
@@ -556,184 +750,212 @@ void ResidualQuantizer::compute_codes_add_centroids(
556
750
  const float* centroids) const {
557
751
  FAISS_THROW_IF_NOT_MSG(is_trained, "RQ is not trained yet.");
558
752
 
753
+ //
559
754
  size_t mem = memory_per_point();
560
- if (n > 1 && mem * n > max_mem_distances) {
561
- // then split queries to reduce temp memory
562
- size_t bs = max_mem_distances / mem;
563
- if (bs == 0) {
564
- bs = 1; // otherwise we can't do much
565
- }
566
- for (size_t i0 = 0; i0 < n; i0 += bs) {
567
- size_t i1 = std::min(n, i0 + bs);
568
- const float* cent = nullptr;
569
- if (centroids != nullptr) {
570
- cent = centroids + i0 * d;
571
- }
572
- compute_codes_add_centroids(
573
- x + i0 * d, codes_out + i0 * code_size, i1 - i0, cent);
574
- }
575
- return;
755
+
756
+ size_t bs = max_mem_distances / mem;
757
+ if (bs == 0) {
758
+ bs = 1; // otherwise we can't do much
576
759
  }
577
760
 
578
- std::vector<int32_t> codes(max_beam_size * M * n);
579
- std::vector<float> norms;
580
- std::vector<float> distances(max_beam_size * n);
761
+ // prepare memory pools
762
+ ComputeCodesAddCentroidsLUT0MemoryPool pool0;
763
+ ComputeCodesAddCentroidsLUT1MemoryPool pool1;
581
764
 
582
- if (use_beam_LUT == 0) {
583
- std::vector<float> residuals(max_beam_size * n * d);
765
+ for (size_t i0 = 0; i0 < n; i0 += bs) {
766
+ size_t i1 = std::min(n, i0 + bs);
767
+ const float* cent = nullptr;
768
+ if (centroids != nullptr) {
769
+ cent = centroids + i0 * d;
770
+ }
584
771
 
585
- refine_beam(
586
- n,
587
- 1,
588
- x,
589
- max_beam_size,
590
- codes.data(),
591
- residuals.data(),
592
- distances.data());
593
-
594
- if (search_type == ST_norm_float || search_type == ST_norm_qint8 ||
595
- search_type == ST_norm_qint4) {
596
- norms.resize(n);
597
- // recover the norms of reconstruction as
598
- // || original_vector - residual ||^2
599
- for (size_t i = 0; i < n; i++) {
600
- norms[i] = fvec_L2sqr(
601
- x + i * d, residuals.data() + i * max_beam_size * d, d);
602
- }
772
+ // compute_codes_add_centroids(
773
+ // x + i0 * d,
774
+ // codes_out + i0 * code_size,
775
+ // i1 - i0,
776
+ // cent);
777
+ if (use_beam_LUT == 0) {
778
+ compute_codes_add_centroids_mp_lut0(
779
+ *this,
780
+ x + i0 * d,
781
+ codes_out + i0 * code_size,
782
+ i1 - i0,
783
+ cent,
784
+ pool0);
785
+ } else if (use_beam_LUT == 1) {
786
+ compute_codes_add_centroids_mp_lut1(
787
+ *this,
788
+ x + i0 * d,
789
+ codes_out + i0 * code_size,
790
+ i1 - i0,
791
+ cent,
792
+ pool1);
603
793
  }
604
- } else if (use_beam_LUT == 1) {
605
- FAISS_THROW_IF_NOT_MSG(
606
- codebook_cross_products.size() ==
607
- total_codebook_size * total_codebook_size,
608
- "call compute_codebook_tables first");
609
-
610
- std::vector<float> query_norms(n);
611
- fvec_norms_L2sqr(query_norms.data(), x, d, n);
612
-
613
- std::vector<float> query_cp(n * total_codebook_size);
614
- {
615
- FINTEGER ti = total_codebook_size, di = d, ni = n;
616
- float zero = 0, one = 1;
617
- sgemm_("Transposed",
618
- "Not transposed",
619
- &ti,
620
- &ni,
621
- &di,
622
- &one,
623
- codebooks.data(),
624
- &di,
625
- x,
626
- &di,
627
- &zero,
628
- query_cp.data(),
629
- &ti);
630
- }
631
-
632
- refine_beam_LUT(
633
- n,
634
- query_norms.data(),
635
- query_cp.data(),
636
- max_beam_size,
637
- codes.data(),
638
- distances.data());
639
- }
640
- // pack only the first code of the beam (hence the ld_codes=M *
641
- // max_beam_size)
642
- pack_codes(
643
- n,
644
- codes.data(),
645
- codes_out,
646
- M * max_beam_size,
647
- norms.size() > 0 ? norms.data() : nullptr,
648
- centroids);
794
+ }
649
795
  }
650
796
 
651
- void ResidualQuantizer::refine_beam(
797
+ void refine_beam_mp(
798
+ const ResidualQuantizer& rq,
652
799
  size_t n,
653
800
  size_t beam_size,
654
801
  const float* x,
655
802
  int out_beam_size,
656
803
  int32_t* out_codes,
657
804
  float* out_residuals,
658
- float* out_distances) const {
805
+ float* out_distances,
806
+ RefineBeamMemoryPool& pool) {
659
807
  int cur_beam_size = beam_size;
660
808
 
661
- std::vector<float> residuals(x, x + n * d * beam_size);
662
- std::vector<int32_t> codes;
663
- std::vector<float> distances;
664
809
  double t0 = getmillisecs();
665
810
 
811
+ // find the max_beam_size
812
+ int max_beam_size = 0;
813
+ {
814
+ int tmp_beam_size = cur_beam_size;
815
+ for (int m = 0; m < rq.M; m++) {
816
+ int K = 1 << rq.nbits[m];
817
+ int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
818
+ tmp_beam_size = new_beam_size;
819
+
820
+ if (max_beam_size < new_beam_size) {
821
+ max_beam_size = new_beam_size;
822
+ }
823
+ }
824
+ }
825
+
826
+ // preallocate buffers
827
+ pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
828
+ pool.new_residuals.resize(n * max_beam_size * rq.d);
829
+
830
+ pool.codes.resize(n * max_beam_size * (rq.M + 1));
831
+ pool.distances.resize(n * max_beam_size);
832
+ pool.residuals.resize(n * rq.d * max_beam_size);
833
+
834
+ for (size_t i = 0; i < n * rq.d * beam_size; i++) {
835
+ pool.residuals[i] = x[i];
836
+ }
837
+
838
+ // set up pointers to buffers
839
+ int32_t* __restrict codes_ptr = pool.codes.data();
840
+ float* __restrict residuals_ptr = pool.residuals.data();
841
+
842
+ int32_t* __restrict new_codes_ptr = pool.new_codes.data();
843
+ float* __restrict new_residuals_ptr = pool.new_residuals.data();
844
+
845
+ // index
666
846
  std::unique_ptr<Index> assign_index;
667
- if (assign_index_factory) {
668
- assign_index.reset((*assign_index_factory)(d));
847
+ if (rq.assign_index_factory) {
848
+ assign_index.reset((*rq.assign_index_factory)(rq.d));
669
849
  } else {
670
- assign_index.reset(new IndexFlatL2(d));
850
+ assign_index.reset(new IndexFlatL2(rq.d));
671
851
  }
672
852
 
673
- for (int m = 0; m < M; m++) {
674
- int K = 1 << nbits[m];
853
+ // main loop
854
+ size_t codes_size = 0;
855
+ size_t distances_size = 0;
856
+ size_t residuals_size = 0;
675
857
 
676
- const float* codebooks_m =
677
- this->codebooks.data() + codebook_offsets[m] * d;
858
+ for (int m = 0; m < rq.M; m++) {
859
+ int K = 1 << rq.nbits[m];
678
860
 
679
- int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
861
+ const float* __restrict codebooks_m =
862
+ rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
680
863
 
681
- std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
682
- std::vector<float> new_residuals(n * new_beam_size * d);
683
- distances.resize(n * new_beam_size);
864
+ const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
865
+
866
+ codes_size = n * new_beam_size * (m + 1);
867
+ residuals_size = n * new_beam_size * rq.d;
868
+ distances_size = n * new_beam_size;
684
869
 
685
870
  beam_search_encode_step(
686
- d,
871
+ rq.d,
687
872
  K,
688
873
  codebooks_m,
689
874
  n,
690
875
  cur_beam_size,
691
- residuals.data(),
876
+ // residuals.data(),
877
+ residuals_ptr,
692
878
  m,
693
- codes.data(),
879
+ // codes.data(),
880
+ codes_ptr,
694
881
  new_beam_size,
695
- new_codes.data(),
696
- new_residuals.data(),
697
- distances.data(),
698
- assign_index.get());
882
+ // new_codes.data(),
883
+ new_codes_ptr,
884
+ // new_residuals.data(),
885
+ new_residuals_ptr,
886
+ pool.distances.data(),
887
+ assign_index.get(),
888
+ rq.approx_topk_mode);
699
889
 
700
890
  assign_index->reset();
701
891
 
702
- codes.swap(new_codes);
703
- residuals.swap(new_residuals);
892
+ std::swap(codes_ptr, new_codes_ptr);
893
+ std::swap(residuals_ptr, new_residuals_ptr);
704
894
 
705
895
  cur_beam_size = new_beam_size;
706
896
 
707
- if (verbose) {
897
+ if (rq.verbose) {
708
898
  float sum_distances = 0;
709
- for (int j = 0; j < distances.size(); j++) {
710
- sum_distances += distances[j];
899
+ // for (int j = 0; j < distances.size(); j++) {
900
+ // sum_distances += distances[j];
901
+ // }
902
+ for (int j = 0; j < distances_size; j++) {
903
+ sum_distances += pool.distances[j];
711
904
  }
905
+
712
906
  printf("[%.3f s] encode stage %d, %d bits, "
713
907
  "total error %g, beam_size %d\n",
714
908
  (getmillisecs() - t0) / 1000,
715
909
  m,
716
- int(nbits[m]),
910
+ int(rq.nbits[m]),
717
911
  sum_distances,
718
912
  cur_beam_size);
719
913
  }
720
914
  }
721
915
 
722
916
  if (out_codes) {
723
- memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
917
+ // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
918
+ memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
724
919
  }
725
920
  if (out_residuals) {
921
+ // memcpy(out_residuals,
922
+ // residuals.data(),
923
+ // residuals.size() * sizeof(residuals[0]));
726
924
  memcpy(out_residuals,
727
- residuals.data(),
728
- residuals.size() * sizeof(residuals[0]));
925
+ residuals_ptr,
926
+ residuals_size * sizeof(*residuals_ptr));
729
927
  }
730
928
  if (out_distances) {
929
+ // memcpy(out_distances,
930
+ // distances.data(),
931
+ // distances.size() * sizeof(distances[0]));
731
932
  memcpy(out_distances,
732
- distances.data(),
733
- distances.size() * sizeof(distances[0]));
933
+ pool.distances.data(),
934
+ distances_size * sizeof(pool.distances[0]));
734
935
  }
735
936
  }
736
937
 
938
+ void ResidualQuantizer::refine_beam(
939
+ size_t n,
940
+ size_t beam_size,
941
+ const float* x,
942
+ int out_beam_size,
943
+ int32_t* out_codes,
944
+ float* out_residuals,
945
+ float* out_distances) const {
946
+ RefineBeamMemoryPool pool;
947
+ refine_beam_mp(
948
+ *this,
949
+ n,
950
+ beam_size,
951
+ x,
952
+ out_beam_size,
953
+ out_codes,
954
+ out_residuals,
955
+ out_distances,
956
+ pool);
957
+ }
958
+
737
959
  /*******************************************************************
738
960
  * Functions using the dot products between codebook entries
739
961
  *******************************************************************/
@@ -765,6 +987,186 @@ void ResidualQuantizer::compute_codebook_tables() {
765
987
  }
766
988
  }
767
989
 
990
+ namespace {
991
+
992
+ template <size_t M, size_t NK>
993
+ void accum_and_store_tab(
994
+ const size_t m_offset,
995
+ const float* const __restrict codebook_cross_norms,
996
+ const uint64_t* const __restrict codebook_offsets,
997
+ const int32_t* const __restrict codes_i,
998
+ const size_t b,
999
+ const size_t ldc,
1000
+ const size_t K,
1001
+ float* const __restrict output) {
1002
+ // load pointers into registers
1003
+ const float* cbs[M];
1004
+ for (size_t ij = 0; ij < M; ij++) {
1005
+ const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
1006
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1007
+ }
1008
+
1009
+ // do accumulation in registers using SIMD.
1010
+ // It is possible that compiler may be smart enough so that
1011
+ // this manual SIMD unrolling might be unneeded.
1012
+ #if defined(__AVX2__) || defined(__aarch64__)
1013
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
1014
+
1015
+ // process in chunks of size (8 * NK) floats
1016
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1017
+ simd8float32 regs[NK];
1018
+ for (size_t ik = 0; ik < NK; ik++) {
1019
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
1020
+ }
1021
+
1022
+ for (size_t ij = 1; ij < M; ij++) {
1023
+ for (size_t ik = 0; ik < NK; ik++) {
1024
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1025
+ }
1026
+ }
1027
+
1028
+ // write the result
1029
+ for (size_t ik = 0; ik < NK; ik++) {
1030
+ regs[ik].storeu(output + kk + ik * 8);
1031
+ }
1032
+ }
1033
+ #else
1034
+ const size_t K8 = 0;
1035
+ #endif
1036
+
1037
+ // process leftovers
1038
+ for (size_t kk = K8; kk < K; kk++) {
1039
+ float reg = cbs[0][kk];
1040
+ for (size_t ij = 1; ij < M; ij++) {
1041
+ reg += cbs[ij][kk];
1042
+ }
1043
+ output[b * K + kk] = reg;
1044
+ }
1045
+ }
1046
+
1047
+ template <size_t M, size_t NK>
1048
+ void accum_and_add_tab(
1049
+ const size_t m_offset,
1050
+ const float* const __restrict codebook_cross_norms,
1051
+ const uint64_t* const __restrict codebook_offsets,
1052
+ const int32_t* const __restrict codes_i,
1053
+ const size_t b,
1054
+ const size_t ldc,
1055
+ const size_t K,
1056
+ float* const __restrict output) {
1057
+ // load pointers into registers
1058
+ const float* cbs[M];
1059
+ for (size_t ij = 0; ij < M; ij++) {
1060
+ const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
1061
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1062
+ }
1063
+
1064
+ // do accumulation in registers using SIMD.
1065
+ // It is possible that compiler may be smart enough so that
1066
+ // this manual SIMD unrolling might be unneeded.
1067
+ #if defined(__AVX2__) || defined(__aarch64__)
1068
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
1069
+
1070
+ // process in chunks of size (8 * NK) floats
1071
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1072
+ simd8float32 regs[NK];
1073
+ for (size_t ik = 0; ik < NK; ik++) {
1074
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
1075
+ }
1076
+
1077
+ for (size_t ij = 1; ij < M; ij++) {
1078
+ for (size_t ik = 0; ik < NK; ik++) {
1079
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1080
+ }
1081
+ }
1082
+
1083
+ // write the result
1084
+ for (size_t ik = 0; ik < NK; ik++) {
1085
+ simd8float32 existing(output + kk + ik * 8);
1086
+ existing += regs[ik];
1087
+ existing.storeu(output + kk + ik * 8);
1088
+ }
1089
+ }
1090
+ #else
1091
+ const size_t K8 = 0;
1092
+ #endif
1093
+
1094
+ // process leftovers
1095
+ for (size_t kk = K8; kk < K; kk++) {
1096
+ float reg = cbs[0][kk];
1097
+ for (size_t ij = 1; ij < M; ij++) {
1098
+ reg += cbs[ij][kk];
1099
+ }
1100
+ output[b * K + kk] += reg;
1101
+ }
1102
+ }
1103
+
1104
+ template <size_t M, size_t NK>
1105
+ void accum_and_finalize_tab(
1106
+ const float* const __restrict codebook_cross_norms,
1107
+ const uint64_t* const __restrict codebook_offsets,
1108
+ const int32_t* const __restrict codes_i,
1109
+ const size_t b,
1110
+ const size_t ldc,
1111
+ const size_t K,
1112
+ const float* const __restrict distances_i,
1113
+ const float* const __restrict cd_common,
1114
+ float* const __restrict output) {
1115
+ // load pointers into registers
1116
+ const float* cbs[M];
1117
+ for (size_t ij = 0; ij < M; ij++) {
1118
+ const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
1119
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1120
+ }
1121
+
1122
+ // do accumulation in registers using SIMD.
1123
+ // It is possible that compiler may be smart enough so that
1124
+ // this manual SIMD unrolling might be unneeded.
1125
+ #if defined(__AVX2__) || defined(__aarch64__)
1126
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
1127
+
1128
+ // process in chunks of size (8 * NK) floats
1129
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1130
+ simd8float32 regs[NK];
1131
+ for (size_t ik = 0; ik < NK; ik++) {
1132
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
1133
+ }
1134
+
1135
+ for (size_t ij = 1; ij < M; ij++) {
1136
+ for (size_t ik = 0; ik < NK; ik++) {
1137
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1138
+ }
1139
+ }
1140
+
1141
+ simd8float32 two(2.0f);
1142
+ for (size_t ik = 0; ik < NK; ik++) {
1143
+ // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
1144
+ // + 2 * dp[k];
1145
+
1146
+ simd8float32 common_v(cd_common + kk + ik * 8);
1147
+ common_v = fmadd(two, regs[ik], common_v);
1148
+
1149
+ common_v += simd8float32(distances_i[b]);
1150
+ common_v.storeu(output + b * K + kk + ik * 8);
1151
+ }
1152
+ }
1153
+ #else
1154
+ const size_t K8 = 0;
1155
+ #endif
1156
+
1157
+ // process leftovers
1158
+ for (size_t kk = K8; kk < K; kk++) {
1159
+ float reg = cbs[0][kk];
1160
+ for (size_t ij = 1; ij < M; ij++) {
1161
+ reg += cbs[ij][kk];
1162
+ }
1163
+
1164
+ output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
1165
+ }
1166
+ }
1167
+
1168
+ } // namespace
1169
+
768
1170
  void beam_search_encode_step_tab(
769
1171
  size_t K,
770
1172
  size_t n,
@@ -779,12 +1181,13 @@ void beam_search_encode_step_tab(
779
1181
  const int32_t* codes, // n * beam_size * m
780
1182
  const float* distances, // n * beam_size
781
1183
  size_t new_beam_size,
782
- int32_t* new_codes, // n * new_beam_size * (m + 1)
783
- float* new_distances) // n * new_beam_size
1184
+ int32_t* new_codes, // n * new_beam_size * (m + 1)
1185
+ float* new_distances, // n * new_beam_size
1186
+ ApproxTopK_mode_t approx_topk_mode) //
784
1187
  {
785
1188
  FAISS_THROW_IF_NOT(ldc >= K);
786
1189
 
787
- #pragma omp parallel for if (n > 100)
1190
+ #pragma omp parallel for if (n > 100) schedule(dynamic)
788
1191
  for (int64_t i = 0; i < n; i++) {
789
1192
  std::vector<float> cent_distances(beam_size * K);
790
1193
  std::vector<float> cd_common(K);
@@ -797,6 +1200,14 @@ void beam_search_encode_step_tab(
797
1200
  cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
798
1201
  }
799
1202
 
1203
+ /*
1204
+ // This is the baseline implementation. Its primary flaw
1205
+ // that it writes way too many info to the temporary buffer
1206
+ // called dp.
1207
+ //
1208
+ // This baseline code is kept intentionally because it is easy to
1209
+ // understand what an optimized version optimizes exactly.
1210
+ //
800
1211
  for (size_t b = 0; b < beam_size; b++) {
801
1212
  std::vector<float> dp(K);
802
1213
 
@@ -812,6 +1223,117 @@ void beam_search_encode_step_tab(
812
1223
  distances_i[b] + cd_common[k] + 2 * dp[k];
813
1224
  }
814
1225
  }
1226
+ */
1227
+
1228
+ // An optimized implementation that avoids using a temporary buffer
1229
+ // and does the accumulation in registers.
1230
+
1231
+ // Compute a sum of NK AQ codes.
1232
+ #define ACCUM_AND_FINALIZE_TAB(NK) \
1233
+ case NK: \
1234
+ for (size_t b = 0; b < beam_size; b++) { \
1235
+ accum_and_finalize_tab<NK, 4>( \
1236
+ codebook_cross_norms, \
1237
+ codebook_offsets, \
1238
+ codes_i, \
1239
+ b, \
1240
+ ldc, \
1241
+ K, \
1242
+ distances_i, \
1243
+ cd_common.data(), \
1244
+ cent_distances.data()); \
1245
+ } \
1246
+ break;
1247
+
1248
+ // this version contains many switch-case scenarios, but
1249
+ // they won't affect branch predictor.
1250
+ switch (m) {
1251
+ case 0:
1252
+ // trivial case
1253
+ for (size_t b = 0; b < beam_size; b++) {
1254
+ for (size_t k = 0; k < K; k++) {
1255
+ cent_distances[b * K + k] =
1256
+ distances_i[b] + cd_common[k];
1257
+ }
1258
+ }
1259
+ break;
1260
+
1261
+ ACCUM_AND_FINALIZE_TAB(1)
1262
+ ACCUM_AND_FINALIZE_TAB(2)
1263
+ ACCUM_AND_FINALIZE_TAB(3)
1264
+ ACCUM_AND_FINALIZE_TAB(4)
1265
+ ACCUM_AND_FINALIZE_TAB(5)
1266
+ ACCUM_AND_FINALIZE_TAB(6)
1267
+ ACCUM_AND_FINALIZE_TAB(7)
1268
+
1269
+ default: {
1270
+ // m >= 8 case.
1271
+
1272
+ // A temporary buffer has to be used due to the lack of
1273
+ // registers. But we'll try to accumulate up to 8 AQ codes in
1274
+ // registers and issue a single write operation to the buffer,
1275
+ // while the baseline does no accumulation. So, the number of
1276
+ // write operations to the temporary buffer is reduced 8x.
1277
+
1278
+ // allocate a temporary buffer
1279
+ std::vector<float> dp(K);
1280
+
1281
+ for (size_t b = 0; b < beam_size; b++) {
1282
+ // Initialize it. Compute a sum of first 8 AQ codes
1283
+ // because m >= 8 .
1284
+ accum_and_store_tab<8, 4>(
1285
+ m,
1286
+ codebook_cross_norms,
1287
+ codebook_offsets,
1288
+ codes_i,
1289
+ b,
1290
+ ldc,
1291
+ K,
1292
+ dp.data());
1293
+
1294
+ #define ACCUM_AND_ADD_TAB(NK) \
1295
+ case NK: \
1296
+ accum_and_add_tab<NK, 4>( \
1297
+ m, \
1298
+ codebook_cross_norms, \
1299
+ codebook_offsets + im, \
1300
+ codes_i + im, \
1301
+ b, \
1302
+ ldc, \
1303
+ K, \
1304
+ dp.data()); \
1305
+ break;
1306
+
1307
+ // accumulate up to 8 additional AQ codes into
1308
+ // a temporary buffer
1309
+ for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
1310
+ size_t m_left = m - im;
1311
+ if (m_left > 8) {
1312
+ m_left = 8;
1313
+ }
1314
+
1315
+ switch (m_left) {
1316
+ ACCUM_AND_ADD_TAB(1)
1317
+ ACCUM_AND_ADD_TAB(2)
1318
+ ACCUM_AND_ADD_TAB(3)
1319
+ ACCUM_AND_ADD_TAB(4)
1320
+ ACCUM_AND_ADD_TAB(5)
1321
+ ACCUM_AND_ADD_TAB(6)
1322
+ ACCUM_AND_ADD_TAB(7)
1323
+ ACCUM_AND_ADD_TAB(8)
1324
+ }
1325
+ }
1326
+
1327
+ // done. finalize the result
1328
+ for (size_t k = 0; k < K; k++) {
1329
+ cent_distances[b * K + k] =
1330
+ distances_i[b] + cd_common[k] + 2 * dp[k];
1331
+ }
1332
+ }
1333
+ }
1334
+ }
1335
+
1336
+ // the optimized implementation ends here
815
1337
 
816
1338
  using C = CMax<float, int>;
817
1339
  int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
@@ -824,15 +1346,38 @@ void beam_search_encode_step_tab(
824
1346
  new_distances_i[i] = C::neutral();
825
1347
  }
826
1348
  std::vector<int> perm(new_beam_size, -1);
827
- heap_addn<C>(
828
- new_beam_size,
829
- new_distances_i,
830
- perm.data(),
831
- cent_distances_i,
832
- nullptr,
833
- beam_size * K);
1349
+
1350
+ #define HANDLE_APPROX(NB, BD) \
1351
+ case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
1352
+ HeapWithBuckets<C, NB, BD>::bs_addn( \
1353
+ beam_size, \
1354
+ K, \
1355
+ cent_distances_i, \
1356
+ new_beam_size, \
1357
+ new_distances_i, \
1358
+ perm.data()); \
1359
+ break;
1360
+
1361
+ switch (approx_topk_mode) {
1362
+ HANDLE_APPROX(8, 3)
1363
+ HANDLE_APPROX(8, 2)
1364
+ HANDLE_APPROX(16, 2)
1365
+ HANDLE_APPROX(32, 2)
1366
+ default:
1367
+ heap_addn<C>(
1368
+ new_beam_size,
1369
+ new_distances_i,
1370
+ perm.data(),
1371
+ cent_distances_i,
1372
+ nullptr,
1373
+ beam_size * K);
1374
+ break;
1375
+ }
1376
+
834
1377
  heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
835
1378
 
1379
+ #undef HANDLE_APPROX
1380
+
836
1381
  for (int j = 0; j < new_beam_size; j++) {
837
1382
  int js = perm[j] / K;
838
1383
  int ls = perm[j] % K;
@@ -845,70 +1390,147 @@ void beam_search_encode_step_tab(
845
1390
  }
846
1391
  }
847
1392
 
848
- void ResidualQuantizer::refine_beam_LUT(
1393
+ //
1394
+ void refine_beam_LUT_mp(
1395
+ const ResidualQuantizer& rq,
849
1396
  size_t n,
850
1397
  const float* query_norms, // size n
851
1398
  const float* query_cp, //
852
1399
  int out_beam_size,
853
1400
  int32_t* out_codes,
854
- float* out_distances) const {
1401
+ float* out_distances,
1402
+ RefineBeamLUTMemoryPool& pool) {
855
1403
  int beam_size = 1;
856
1404
 
857
- std::vector<int32_t> codes;
858
- std::vector<float> distances(query_norms, query_norms + n);
859
1405
  double t0 = getmillisecs();
860
1406
 
861
- for (int m = 0; m < M; m++) {
862
- int K = 1 << nbits[m];
1407
+ // find the max_beam_size
1408
+ int max_beam_size = 0;
1409
+ {
1410
+ int tmp_beam_size = beam_size;
1411
+ for (int m = 0; m < rq.M; m++) {
1412
+ int K = 1 << rq.nbits[m];
1413
+ int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
1414
+ tmp_beam_size = new_beam_size;
1415
+
1416
+ if (max_beam_size < new_beam_size) {
1417
+ max_beam_size = new_beam_size;
1418
+ }
1419
+ }
1420
+ }
863
1421
 
1422
+ // preallocate buffers
1423
+ pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
1424
+ pool.new_distances.resize(n * max_beam_size);
1425
+
1426
+ pool.codes.resize(n * max_beam_size * (rq.M + 1));
1427
+ pool.distances.resize(n * max_beam_size);
1428
+
1429
+ for (size_t i = 0; i < n; i++) {
1430
+ pool.distances[i] = query_norms[i];
1431
+ }
1432
+
1433
+ // set up pointers to buffers
1434
+ int32_t* __restrict new_codes_ptr = pool.new_codes.data();
1435
+ float* __restrict new_distances_ptr = pool.new_distances.data();
1436
+
1437
+ int32_t* __restrict codes_ptr = pool.codes.data();
1438
+ float* __restrict distances_ptr = pool.distances.data();
1439
+
1440
+ // main loop
1441
+ size_t codes_size = 0;
1442
+ size_t distances_size = 0;
1443
+ for (int m = 0; m < rq.M; m++) {
1444
+ int K = 1 << rq.nbits[m];
1445
+
1446
+ // it is guaranteed that (new_beam_size <= than max_beam_size) ==
1447
+ // true
864
1448
  int new_beam_size = std::min(beam_size * K, out_beam_size);
865
- std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
866
- std::vector<float> new_distances(n * new_beam_size);
1449
+
1450
+ // std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
1451
+ // std::vector<float> new_distances(n * new_beam_size);
1452
+
1453
+ codes_size = n * new_beam_size * (m + 1);
1454
+ distances_size = n * new_beam_size;
867
1455
 
868
1456
  beam_search_encode_step_tab(
869
1457
  K,
870
1458
  n,
871
1459
  beam_size,
872
- codebook_cross_products.data() + codebook_offsets[m],
873
- total_codebook_size,
874
- codebook_offsets.data(),
875
- query_cp + codebook_offsets[m],
876
- total_codebook_size,
877
- cent_norms.data() + codebook_offsets[m],
1460
+ rq.codebook_cross_products.data() + rq.codebook_offsets[m],
1461
+ rq.total_codebook_size,
1462
+ rq.codebook_offsets.data(),
1463
+ query_cp + rq.codebook_offsets[m],
1464
+ rq.total_codebook_size,
1465
+ rq.cent_norms.data() + rq.codebook_offsets[m],
878
1466
  m,
879
- codes.data(),
880
- distances.data(),
1467
+ // codes.data(),
1468
+ codes_ptr,
1469
+ // distances.data(),
1470
+ distances_ptr,
881
1471
  new_beam_size,
882
- new_codes.data(),
883
- new_distances.data());
1472
+ // new_codes.data(),
1473
+ new_codes_ptr,
1474
+ // new_distances.data()
1475
+ new_distances_ptr,
1476
+ rq.approx_topk_mode);
1477
+
1478
+ // codes.swap(new_codes);
1479
+ std::swap(codes_ptr, new_codes_ptr);
1480
+ // distances.swap(new_distances);
1481
+ std::swap(distances_ptr, new_distances_ptr);
884
1482
 
885
- codes.swap(new_codes);
886
- distances.swap(new_distances);
887
1483
  beam_size = new_beam_size;
888
1484
 
889
- if (verbose) {
1485
+ if (rq.verbose) {
890
1486
  float sum_distances = 0;
891
- for (int j = 0; j < distances.size(); j++) {
892
- sum_distances += distances[j];
1487
+ // for (int j = 0; j < distances.size(); j++) {
1488
+ // sum_distances += distances[j];
1489
+ // }
1490
+ for (int j = 0; j < distances_size; j++) {
1491
+ sum_distances += distances_ptr[j];
893
1492
  }
894
1493
  printf("[%.3f s] encode stage %d, %d bits, "
895
1494
  "total error %g, beam_size %d\n",
896
1495
  (getmillisecs() - t0) / 1000,
897
1496
  m,
898
- int(nbits[m]),
1497
+ int(rq.nbits[m]),
899
1498
  sum_distances,
900
1499
  beam_size);
901
1500
  }
902
1501
  }
903
1502
 
904
1503
  if (out_codes) {
905
- memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
1504
+ // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
1505
+ memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
906
1506
  }
907
1507
  if (out_distances) {
1508
+ // memcpy(out_distances,
1509
+ // distances.data(),
1510
+ // distances.size() * sizeof(distances[0]));
908
1511
  memcpy(out_distances,
909
- distances.data(),
910
- distances.size() * sizeof(distances[0]));
1512
+ distances_ptr,
1513
+ distances_size * sizeof(*distances_ptr));
911
1514
  }
912
1515
  }
913
1516
 
1517
+ void ResidualQuantizer::refine_beam_LUT(
1518
+ size_t n,
1519
+ const float* query_norms, // size n
1520
+ const float* query_cp, //
1521
+ int out_beam_size,
1522
+ int32_t* out_codes,
1523
+ float* out_distances) const {
1524
+ RefineBeamLUTMemoryPool pool;
1525
+ refine_beam_LUT_mp(
1526
+ *this,
1527
+ n,
1528
+ query_norms,
1529
+ query_cp,
1530
+ out_beam_size,
1531
+ out_codes,
1532
+ out_distances,
1533
+ pool);
1534
+ }
1535
+
914
1536
  } // namespace faiss