faiss 0.3.0 → 0.3.1

Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp
@@ -16,17 +16,12 @@
 
 #include <faiss/IndexFlat.h>
 #include <faiss/VectorTransform.h>
-#include <faiss/impl/AuxIndexStructures.h>
 #include <faiss/impl/FaissAssert.h>
-#include <faiss/utils/Heap.h>
+#include <faiss/impl/residual_quantizer_encode_steps.h>
 #include <faiss/utils/distances.h>
 #include <faiss/utils/hamming.h>
 #include <faiss/utils/utils.h>
 
-#include <faiss/utils/simdlib.h>
-
-#include <faiss/utils/approx_topk/approx_topk.h>
-
 extern "C" {
 
 // general matrix multiplication
@@ -125,157 +120,9 @@ void ResidualQuantizer::initialize_from(
     }
 }
 
-void beam_search_encode_step(
-        size_t d,
-        size_t K,
-        const float* cent, /// size (K, d)
-        size_t n,
-        size_t beam_size,
-        const float* residuals, /// size (n, beam_size, d)
-        size_t m,
-        const int32_t* codes, /// size (n, beam_size, m)
-        size_t new_beam_size,
-        int32_t* new_codes, /// size (n, new_beam_size, m + 1)
-        float* new_residuals, /// size (n, new_beam_size, d)
-        float* new_distances, /// size (n, new_beam_size)
-        Index* assign_index,
-        ApproxTopK_mode_t approx_topk_mode) {
-    // we have to fill in the whole output matrix
-    FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
-
-    std::vector<float> cent_distances;
-    std::vector<idx_t> cent_ids;
-
-    if (assign_index) {
-        // search beam_size distances per query
-        FAISS_THROW_IF_NOT(assign_index->d == d);
-        cent_distances.resize(n * beam_size * new_beam_size);
-        cent_ids.resize(n * beam_size * new_beam_size);
-        if (assign_index->ntotal != 0) {
-            // then we assume the codebooks are already added to the index
-            FAISS_THROW_IF_NOT(assign_index->ntotal == K);
-        } else {
-            assign_index->add(K, cent);
-        }
-
-        // printf("beam_search_encode_step -- mem usage %zd\n",
-        // get_mem_usage_kb());
-        assign_index->search(
-                n * beam_size,
-                residuals,
-                new_beam_size,
-                cent_distances.data(),
-                cent_ids.data());
-    } else {
-        // do one big distance computation
-        cent_distances.resize(n * beam_size * K);
-        pairwise_L2sqr(
-                d, n * beam_size, residuals, K, cent, cent_distances.data());
-    }
-    InterruptCallback::check();
-
-#pragma omp parallel for if (n > 100)
-    for (int64_t i = 0; i < n; i++) {
-        const int32_t* codes_i = codes + i * m * beam_size;
-        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
-        const float* residuals_i = residuals + i * d * beam_size;
-        float* new_residuals_i = new_residuals + i * d * new_beam_size;
-
-        float* new_distances_i = new_distances + i * new_beam_size;
-        using C = CMax<float, int>;
-
-        if (assign_index) {
-            const float* cent_distances_i =
-                    cent_distances.data() + i * beam_size * new_beam_size;
-            const idx_t* cent_ids_i =
-                    cent_ids.data() + i * beam_size * new_beam_size;
-
-            // here we could be a tad more efficient by merging sorted arrays
-            for (int i = 0; i < new_beam_size; i++) {
-                new_distances_i[i] = C::neutral();
-            }
-            std::vector<int> perm(new_beam_size, -1);
-            heap_addn<C>(
-                    new_beam_size,
-                    new_distances_i,
-                    perm.data(),
-                    cent_distances_i,
-                    nullptr,
-                    beam_size * new_beam_size);
-            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
-
-            for (int j = 0; j < new_beam_size; j++) {
-                int js = perm[j] / new_beam_size;
-                int ls = cent_ids_i[perm[j]];
-                if (m > 0) {
-                    memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
-                }
-                new_codes_i[m] = ls;
-                new_codes_i += m + 1;
-                fvec_sub(
-                        d,
-                        residuals_i + js * d,
-                        cent + ls * d,
-                        new_residuals_i);
-                new_residuals_i += d;
-            }
-
-        } else {
-            const float* cent_distances_i =
-                    cent_distances.data() + i * beam_size * K;
-            // then we have to select the best results
-            for (int i = 0; i < new_beam_size; i++) {
-                new_distances_i[i] = C::neutral();
-            }
-            std::vector<int> perm(new_beam_size, -1);
-
-#define HANDLE_APPROX(NB, BD)                                  \
-    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
-        HeapWithBuckets<C, NB, BD>::bs_addn(                   \
-                beam_size,                                     \
-                K,                                             \
-                cent_distances_i,                              \
-                new_beam_size,                                 \
-                new_distances_i,                               \
-                perm.data());                                  \
-        break;
-
-            switch (approx_topk_mode) {
-                HANDLE_APPROX(8, 3)
-                HANDLE_APPROX(8, 2)
-                HANDLE_APPROX(16, 2)
-                HANDLE_APPROX(32, 2)
-                default:
-                    heap_addn<C>(
-                            new_beam_size,
-                            new_distances_i,
-                            perm.data(),
-                            cent_distances_i,
-                            nullptr,
-                            beam_size * K);
-            }
-            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
-
-#undef HANDLE_APPROX
-
-            for (int j = 0; j < new_beam_size; j++) {
-                int js = perm[j] / K;
-                int ls = perm[j] % K;
-                if (m > 0) {
-                    memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
-                }
-                new_codes_i[m] = ls;
-                new_codes_i += m + 1;
-                fvec_sub(
-                        d,
-                        residuals_i + js * d,
-                        cent + ls * d,
-                        new_residuals_i);
-                new_residuals_i += d;
-            }
-        }
-    }
-}
+/****************************************************************
+ * Training
+ ****************************************************************/
 
 void ResidualQuantizer::train(size_t n, const float* x) {
     codebooks.resize(d * codebook_offsets.back());
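Note that beam_search_encode_step() is not deleted outright: per the new include above and the `using namespace rq_encode_steps;` introduced below, the beam-search encoding steps now live in data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp (+960 lines in this release). For orientation, here is a minimal single-query sketch of what one such step does, assuming exhaustive L2 distances and a plain partial sort in place of the faiss heap / approximate top-k machinery; the function name and signature are illustrative, not the faiss API:

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <vector>

// One beam-search encoding step for a single query (illustration only).
// Extends each of beam_size code prefixes by each of the K centroids of the
// current codebook and keeps the new_beam_size best extensions by distance.
// Requires new_beam_size <= beam_size * K.
void beam_search_step(
        size_t d,
        size_t K,
        const float* cent,      // (K, d) centroids of the current codebook
        size_t beam_size,
        const float* residuals, // (beam_size, d) residuals of beam entries
        size_t m,
        const int32_t* codes,   // (beam_size, m) code prefixes
        size_t new_beam_size,
        int32_t* new_codes,     // (new_beam_size, m + 1) out
        float* new_residuals,   // (new_beam_size, d) out
        float* new_distances) { // (new_beam_size) out
    // squared L2 distance from every beam entry to every centroid
    std::vector<float> cd(beam_size * K);
    for (size_t b = 0; b < beam_size; b++) {
        for (size_t k = 0; k < K; k++) {
            float dist = 0;
            for (size_t l = 0; l < d; l++) {
                float diff = residuals[b * d + l] - cent[k * d + l];
                dist += diff * diff;
            }
            cd[b * K + k] = dist;
        }
    }
    // keep the new_beam_size smallest of the beam_size * K candidates
    std::vector<size_t> perm(beam_size * K);
    std::iota(perm.begin(), perm.end(), size_t(0));
    std::partial_sort(
            perm.begin(), perm.begin() + new_beam_size, perm.end(),
            [&](size_t a, size_t b) { return cd[a] < cd[b]; });
    for (size_t j = 0; j < new_beam_size; j++) {
        size_t js = perm[j] / K; // parent beam entry
        size_t ls = perm[j] % K; // appended centroid id
        if (m > 0) {
            memcpy(new_codes + j * (m + 1), codes + js * m, m * sizeof(int32_t));
        }
        new_codes[j * (m + 1) + m] = int32_t(ls);
        for (size_t l = 0; l < d; l++) {
            new_residuals[j * d + l] = residuals[js * d + l] - cent[ls * d + l];
        }
        new_distances[j] = cd[perm[j]];
    }
}

Running this once per codebook, feeding each step's new_residuals and new_codes into the next, reproduces the overall encoding loop: the beam never grows beyond the requested output beam size.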
@@ -568,180 +415,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
     return mem;
 }
 
-// a namespace full of preallocated buffers
-namespace {
-
-// Preallocated memory chunk for refine_beam_mp() call
-struct RefineBeamMemoryPool {
-    std::vector<int32_t> new_codes;
-    std::vector<float> new_residuals;
-
-    std::vector<float> residuals;
-    std::vector<int32_t> codes;
-    std::vector<float> distances;
-};
-
-// Preallocated memory chunk for refine_beam_LUT_mp() call
-struct RefineBeamLUTMemoryPool {
-    std::vector<int32_t> new_codes;
-    std::vector<float> new_distances;
-
-    std::vector<int32_t> codes;
-    std::vector<float> distances;
-};
-
-// this is for use_beam_LUT == 0 in compute_codes_add_centroids_mp_lut0() call
-struct ComputeCodesAddCentroidsLUT0MemoryPool {
-    std::vector<int32_t> codes;
-    std::vector<float> norms;
-    std::vector<float> distances;
-    std::vector<float> residuals;
-    RefineBeamMemoryPool refine_beam_pool;
-};
-
-// this is for use_beam_LUT == 1 in compute_codes_add_centroids_mp_lut1() call
-struct ComputeCodesAddCentroidsLUT1MemoryPool {
-    std::vector<int32_t> codes;
-    std::vector<float> distances;
-    std::vector<float> query_norms;
-    std::vector<float> query_cp;
-    std::vector<float> residuals;
-    RefineBeamLUTMemoryPool refine_beam_lut_pool;
-};
-
-} // namespace
-
-// forward declaration
-void refine_beam_mp(
-        const ResidualQuantizer& rq,
-        size_t n,
-        size_t beam_size,
-        const float* x,
-        int out_beam_size,
-        int32_t* out_codes,
-        float* out_residuals,
-        float* out_distances,
-        RefineBeamMemoryPool& pool);
-
-// forward declaration
-void refine_beam_LUT_mp(
-        const ResidualQuantizer& rq,
-        size_t n,
-        const float* query_norms, // size n
-        const float* query_cp,    //
-        int out_beam_size,
-        int32_t* out_codes,
-        float* out_distances,
-        RefineBeamLUTMemoryPool& pool);
-
-// this is for use_beam_LUT == 0
-void compute_codes_add_centroids_mp_lut0(
-        const ResidualQuantizer& rq,
-        const float* x,
-        uint8_t* codes_out,
-        size_t n,
-        const float* centroids,
-        ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
-    pool.codes.resize(rq.max_beam_size * rq.M * n);
-    pool.distances.resize(rq.max_beam_size * n);
+/****************************************************************
+ * Encoding
+ ****************************************************************/
 
-    pool.residuals.resize(rq.max_beam_size * n * rq.d);
-
-    refine_beam_mp(
-            rq,
-            n,
-            1,
-            x,
-            rq.max_beam_size,
-            pool.codes.data(),
-            pool.residuals.data(),
-            pool.distances.data(),
-            pool.refine_beam_pool);
-
-    if (rq.search_type == ResidualQuantizer::ST_norm_float ||
-        rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
-        rq.search_type == ResidualQuantizer::ST_norm_qint4) {
-        pool.norms.resize(n);
-        // recover the norms of reconstruction as
-        // || original_vector - residual ||^2
-        for (size_t i = 0; i < n; i++) {
-            pool.norms[i] = fvec_L2sqr(
-                    x + i * rq.d,
-                    pool.residuals.data() + i * rq.max_beam_size * rq.d,
-                    rq.d);
-        }
-    }
-
-    // pack only the first code of the beam
-    // (hence the ld_codes=M * max_beam_size)
-    rq.pack_codes(
-            n,
-            pool.codes.data(),
-            codes_out,
-            rq.M * rq.max_beam_size,
-            (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
-            centroids);
-}
-
-// use_beam_LUT == 1
-void compute_codes_add_centroids_mp_lut1(
-        const ResidualQuantizer& rq,
-        const float* x,
-        uint8_t* codes_out,
-        size_t n,
-        const float* centroids,
-        ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
-    //
-    pool.codes.resize(rq.max_beam_size * rq.M * n);
-    pool.distances.resize(rq.max_beam_size * n);
-
-    FAISS_THROW_IF_NOT_MSG(
-            rq.codebook_cross_products.size() ==
-                    rq.total_codebook_size * rq.total_codebook_size,
-            "call compute_codebook_tables first");
-
-    pool.query_norms.resize(n);
-    fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
-
-    pool.query_cp.resize(n * rq.total_codebook_size);
-    {
-        FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
-        float zero = 0, one = 1;
-        sgemm_("Transposed",
-               "Not transposed",
-               &ti,
-               &ni,
-               &di,
-               &one,
-               rq.codebooks.data(),
-               &di,
-               x,
-               &di,
-               &zero,
-               pool.query_cp.data(),
-               &ti);
-    }
-
-    refine_beam_LUT_mp(
-            rq,
-            n,
-            pool.query_norms.data(),
-            pool.query_cp.data(),
-            rq.max_beam_size,
-            pool.codes.data(),
-            pool.distances.data(),
-            pool.refine_beam_lut_pool);
-
-    // pack only the first code of the beam
-    // (hence the ld_codes=M * max_beam_size)
-    rq.pack_codes(
-            n,
-            pool.codes.data(),
-            codes_out,
-            rq.M * rq.max_beam_size,
-            nullptr,
-            centroids);
-}
+using namespace rq_encode_steps;
 
 void ResidualQuantizer::compute_codes_add_centroids(
         const float* x,
@@ -769,11 +447,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
             cent = centroids + i0 * d;
         }
 
-        // compute_codes_add_centroids(
-        //         x + i0 * d,
-        //         codes_out + i0 * code_size,
-        //         i1 - i0,
-        //         cent);
         if (use_beam_LUT == 0) {
             compute_codes_add_centroids_mp_lut0(
                     *this,
@@ -794,147 +467,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
     }
 }
 
-void refine_beam_mp(
-        const ResidualQuantizer& rq,
-        size_t n,
-        size_t beam_size,
-        const float* x,
-        int out_beam_size,
-        int32_t* out_codes,
-        float* out_residuals,
-        float* out_distances,
-        RefineBeamMemoryPool& pool) {
-    int cur_beam_size = beam_size;
-
-    double t0 = getmillisecs();
-
-    // find the max_beam_size
-    int max_beam_size = 0;
-    {
-        int tmp_beam_size = cur_beam_size;
-        for (int m = 0; m < rq.M; m++) {
-            int K = 1 << rq.nbits[m];
-            int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
-            tmp_beam_size = new_beam_size;
-
-            if (max_beam_size < new_beam_size) {
-                max_beam_size = new_beam_size;
-            }
-        }
-    }
-
-    // preallocate buffers
-    pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
-    pool.new_residuals.resize(n * max_beam_size * rq.d);
-
-    pool.codes.resize(n * max_beam_size * (rq.M + 1));
-    pool.distances.resize(n * max_beam_size);
-    pool.residuals.resize(n * rq.d * max_beam_size);
-
-    for (size_t i = 0; i < n * rq.d * beam_size; i++) {
-        pool.residuals[i] = x[i];
-    }
-
-    // set up pointers to buffers
-    int32_t* __restrict codes_ptr = pool.codes.data();
-    float* __restrict residuals_ptr = pool.residuals.data();
-
-    int32_t* __restrict new_codes_ptr = pool.new_codes.data();
-    float* __restrict new_residuals_ptr = pool.new_residuals.data();
-
-    // index
-    std::unique_ptr<Index> assign_index;
-    if (rq.assign_index_factory) {
-        assign_index.reset((*rq.assign_index_factory)(rq.d));
-    } else {
-        assign_index.reset(new IndexFlatL2(rq.d));
-    }
-
-    // main loop
-    size_t codes_size = 0;
-    size_t distances_size = 0;
-    size_t residuals_size = 0;
-
-    for (int m = 0; m < rq.M; m++) {
-        int K = 1 << rq.nbits[m];
-
-        const float* __restrict codebooks_m =
-                rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
-
-        const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
-
-        codes_size = n * new_beam_size * (m + 1);
-        residuals_size = n * new_beam_size * rq.d;
-        distances_size = n * new_beam_size;
-
-        beam_search_encode_step(
-                rq.d,
-                K,
-                codebooks_m,
-                n,
-                cur_beam_size,
-                // residuals.data(),
-                residuals_ptr,
-                m,
-                // codes.data(),
-                codes_ptr,
-                new_beam_size,
-                // new_codes.data(),
-                new_codes_ptr,
-                // new_residuals.data(),
-                new_residuals_ptr,
-                pool.distances.data(),
-                assign_index.get(),
-                rq.approx_topk_mode);
-
-        assign_index->reset();
-
-        std::swap(codes_ptr, new_codes_ptr);
-        std::swap(residuals_ptr, new_residuals_ptr);
-
-        cur_beam_size = new_beam_size;
-
-        if (rq.verbose) {
-            float sum_distances = 0;
-            // for (int j = 0; j < distances.size(); j++) {
-            //     sum_distances += distances[j];
-            // }
-            for (int j = 0; j < distances_size; j++) {
-                sum_distances += pool.distances[j];
-            }
-
-            printf("[%.3f s] encode stage %d, %d bits, "
-                   "total error %g, beam_size %d\n",
-                   (getmillisecs() - t0) / 1000,
-                   m,
-                   int(rq.nbits[m]),
-                   sum_distances,
-                   cur_beam_size);
-        }
-    }
-
-    if (out_codes) {
-        // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
-        memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
-    }
-    if (out_residuals) {
-        // memcpy(out_residuals,
-        //        residuals.data(),
-        //        residuals.size() * sizeof(residuals[0]));
-        memcpy(out_residuals,
-               residuals_ptr,
-               residuals_size * sizeof(*residuals_ptr));
-    }
-    if (out_distances) {
-        // memcpy(out_distances,
-        //        distances.data(),
-        //        distances.size() * sizeof(distances[0]));
-        memcpy(out_distances,
-               pool.distances.data(),
-               distances_size * sizeof(pool.distances[0]));
-    }
-}
-
 void ResidualQuantizer::refine_beam(
         size_t n,
         size_t beam_size,
@@ -961,556 +493,36 @@ void ResidualQuantizer::refine_beam(
  *******************************************************************/
 
 void ResidualQuantizer::compute_codebook_tables() {
-    codebook_cross_products.resize(total_codebook_size * total_codebook_size);
     cent_norms.resize(total_codebook_size);
-    // stricly speaking we could use ssyrk
-    {
-        FINTEGER ni = total_codebook_size;
+    fvec_norms_L2sqr(
+            cent_norms.data(), codebooks.data(), d, total_codebook_size);
+    size_t cross_table_size = 0;
+    for (int m = 0; m < M; m++) {
+        size_t K = (size_t)1 << nbits[m];
+        cross_table_size += K * codebook_offsets[m];
+    }
+    codebook_cross_products.resize(cross_table_size);
+    size_t ofs = 0;
+    for (int m = 1; m < M; m++) {
+        FINTEGER ki = (size_t)1 << nbits[m];
+        FINTEGER kk = codebook_offsets[m];
         FINTEGER di = d;
         float zero = 0, one = 1;
+        assert(ofs + ki * kk <= cross_table_size);
         sgemm_("Transposed",
                "Not transposed",
-               &ni,
-               &ni,
+               &ki,
+               &kk,
                &di,
                &one,
-               codebooks.data(),
+               codebooks.data() + d * kk,
                &di,
                codebooks.data(),
                &di,
               &zero,
-               codebook_cross_products.data(),
-               &ni);
-    }
-    for (size_t i = 0; i < total_codebook_size; i++) {
-        cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
-    }
-}
-
-namespace {
-
-template <size_t M, size_t NK>
-void accum_and_store_tab(
-        const size_t m_offset,
-        const float* const __restrict codebook_cross_norms,
-        const uint64_t* const __restrict codebook_offsets,
-        const int32_t* const __restrict codes_i,
-        const size_t b,
-        const size_t ldc,
-        const size_t K,
-        float* const __restrict output) {
-    // load pointers into registers
-    const float* cbs[M];
-    for (size_t ij = 0; ij < M; ij++) {
-        const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
-        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
-    }
-
-    // do accumulation in registers using SIMD.
-    // It is possible that compiler may be smart enough so that
-    // this manual SIMD unrolling might be unneeded.
-#if defined(__AVX2__) || defined(__aarch64__)
-    const size_t K8 = (K / (8 * NK)) * (8 * NK);
-
-    // process in chunks of size (8 * NK) floats
-    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
-        simd8float32 regs[NK];
-        for (size_t ik = 0; ik < NK; ik++) {
-            regs[ik].loadu(cbs[0] + kk + ik * 8);
-        }
-
-        for (size_t ij = 1; ij < M; ij++) {
-            for (size_t ik = 0; ik < NK; ik++) {
-                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
-            }
-        }
-
-        // write the result
-        for (size_t ik = 0; ik < NK; ik++) {
-            regs[ik].storeu(output + kk + ik * 8);
-        }
-    }
-#else
-    const size_t K8 = 0;
-#endif
-
-    // process leftovers
-    for (size_t kk = K8; kk < K; kk++) {
-        float reg = cbs[0][kk];
-        for (size_t ij = 1; ij < M; ij++) {
-            reg += cbs[ij][kk];
-        }
-        output[b * K + kk] = reg;
-    }
-}
-
-template <size_t M, size_t NK>
-void accum_and_add_tab(
-        const size_t m_offset,
-        const float* const __restrict codebook_cross_norms,
-        const uint64_t* const __restrict codebook_offsets,
-        const int32_t* const __restrict codes_i,
-        const size_t b,
-        const size_t ldc,
-        const size_t K,
-        float* const __restrict output) {
-    // load pointers into registers
-    const float* cbs[M];
-    for (size_t ij = 0; ij < M; ij++) {
-        const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
-        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
-    }
-
-    // do accumulation in registers using SIMD.
-    // It is possible that compiler may be smart enough so that
-    // this manual SIMD unrolling might be unneeded.
-#if defined(__AVX2__) || defined(__aarch64__)
-    const size_t K8 = (K / (8 * NK)) * (8 * NK);
-
-    // process in chunks of size (8 * NK) floats
-    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
-        simd8float32 regs[NK];
-        for (size_t ik = 0; ik < NK; ik++) {
-            regs[ik].loadu(cbs[0] + kk + ik * 8);
-        }
-
-        for (size_t ij = 1; ij < M; ij++) {
-            for (size_t ik = 0; ik < NK; ik++) {
-                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
-            }
-        }
-
-        // write the result
-        for (size_t ik = 0; ik < NK; ik++) {
-            simd8float32 existing(output + kk + ik * 8);
-            existing += regs[ik];
-            existing.storeu(output + kk + ik * 8);
-        }
-    }
-#else
-    const size_t K8 = 0;
-#endif
-
-    // process leftovers
-    for (size_t kk = K8; kk < K; kk++) {
-        float reg = cbs[0][kk];
-        for (size_t ij = 1; ij < M; ij++) {
-            reg += cbs[ij][kk];
-        }
-        output[b * K + kk] += reg;
-    }
-}
-
-template <size_t M, size_t NK>
-void accum_and_finalize_tab(
-        const float* const __restrict codebook_cross_norms,
-        const uint64_t* const __restrict codebook_offsets,
-        const int32_t* const __restrict codes_i,
-        const size_t b,
-        const size_t ldc,
-        const size_t K,
-        const float* const __restrict distances_i,
-        const float* const __restrict cd_common,
-        float* const __restrict output) {
-    // load pointers into registers
-    const float* cbs[M];
-    for (size_t ij = 0; ij < M; ij++) {
-        const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
-        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
-    }
-
-    // do accumulation in registers using SIMD.
-    // It is possible that compiler may be smart enough so that
-    // this manual SIMD unrolling might be unneeded.
-#if defined(__AVX2__) || defined(__aarch64__)
-    const size_t K8 = (K / (8 * NK)) * (8 * NK);
-
-    // process in chunks of size (8 * NK) floats
-    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
-        simd8float32 regs[NK];
-        for (size_t ik = 0; ik < NK; ik++) {
-            regs[ik].loadu(cbs[0] + kk + ik * 8);
-        }
-
-        for (size_t ij = 1; ij < M; ij++) {
-            for (size_t ik = 0; ik < NK; ik++) {
-                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
-            }
-        }
-
-        simd8float32 two(2.0f);
-        for (size_t ik = 0; ik < NK; ik++) {
-            // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
-            //         + 2 * dp[k];
-
-            simd8float32 common_v(cd_common + kk + ik * 8);
-            common_v = fmadd(two, regs[ik], common_v);
-
-            common_v += simd8float32(distances_i[b]);
-            common_v.storeu(output + b * K + kk + ik * 8);
-        }
-    }
-#else
-    const size_t K8 = 0;
-#endif
-
-    // process leftovers
-    for (size_t kk = K8; kk < K; kk++) {
-        float reg = cbs[0][kk];
-        for (size_t ij = 1; ij < M; ij++) {
-            reg += cbs[ij][kk];
-        }
-
-        output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
-    }
-}
-
-} // namespace
-
-void beam_search_encode_step_tab(
-        size_t K,
-        size_t n,
-        size_t beam_size,                  // input sizes
-        const float* codebook_cross_norms, // size K * ldc
-        size_t ldc,                        // >= K
-        const uint64_t* codebook_offsets,  // m
-        const float* query_cp,             // size n * ldqc
-        size_t ldqc,                       // >= K
-        const float* cent_norms_i,         // size K
-        size_t m,
-        const int32_t* codes,   // n * beam_size * m
-        const float* distances, // n * beam_size
-        size_t new_beam_size,
-        int32_t* new_codes,   // n * new_beam_size * (m + 1)
-        float* new_distances, // n * new_beam_size
-        ApproxTopK_mode_t approx_topk_mode) //
-{
-    FAISS_THROW_IF_NOT(ldc >= K);
-
-#pragma omp parallel for if (n > 100) schedule(dynamic)
-    for (int64_t i = 0; i < n; i++) {
-        std::vector<float> cent_distances(beam_size * K);
-        std::vector<float> cd_common(K);
-
-        const int32_t* codes_i = codes + i * m * beam_size;
-        const float* query_cp_i = query_cp + i * ldqc;
-        const float* distances_i = distances + i * beam_size;
-
-        for (size_t k = 0; k < K; k++) {
-            cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
-        }
-
-        /*
-        // This is the baseline implementation. Its primary flaw
-        // that it writes way too many info to the temporary buffer
-        // called dp.
-        //
-        // This baseline code is kept intentionally because it is easy to
-        // understand what an optimized version optimizes exactly.
-        //
-        for (size_t b = 0; b < beam_size; b++) {
-            std::vector<float> dp(K);
-
-            for (size_t m1 = 0; m1 < m; m1++) {
-                size_t c = codes_i[b * m + m1];
-                const float* cb =
-                        &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
-                fvec_add(K, cb, dp.data(), dp.data());
-            }
-
-            for (size_t k = 0; k < K; k++) {
-                cent_distances[b * K + k] =
-                        distances_i[b] + cd_common[k] + 2 * dp[k];
-            }
-        }
-        */
-
-        // An optimized implementation that avoids using a temporary buffer
-        // and does the accumulation in registers.
-
-        // Compute a sum of NK AQ codes.
-#define ACCUM_AND_FINALIZE_TAB(NK)               \
-    case NK:                                     \
-        for (size_t b = 0; b < beam_size; b++) { \
-            accum_and_finalize_tab<NK, 4>(       \
-                    codebook_cross_norms,        \
-                    codebook_offsets,            \
-                    codes_i,                     \
-                    b,                           \
-                    ldc,                         \
-                    K,                           \
-                    distances_i,                 \
-                    cd_common.data(),            \
-                    cent_distances.data());      \
-        }                                        \
-        break;
-
-        // this version contains many switch-case scenarios, but
-        // they won't affect branch predictor.
-        switch (m) {
-            case 0:
-                // trivial case
-                for (size_t b = 0; b < beam_size; b++) {
-                    for (size_t k = 0; k < K; k++) {
-                        cent_distances[b * K + k] =
-                                distances_i[b] + cd_common[k];
-                    }
-                }
-                break;
-
-            ACCUM_AND_FINALIZE_TAB(1)
-            ACCUM_AND_FINALIZE_TAB(2)
-            ACCUM_AND_FINALIZE_TAB(3)
-            ACCUM_AND_FINALIZE_TAB(4)
-            ACCUM_AND_FINALIZE_TAB(5)
-            ACCUM_AND_FINALIZE_TAB(6)
-            ACCUM_AND_FINALIZE_TAB(7)
-
-            default: {
-                // m >= 8 case.
-
-                // A temporary buffer has to be used due to the lack of
-                // registers. But we'll try to accumulate up to 8 AQ codes in
-                // registers and issue a single write operation to the buffer,
-                // while the baseline does no accumulation. So, the number of
-                // write operations to the temporary buffer is reduced 8x.
-
-                // allocate a temporary buffer
-                std::vector<float> dp(K);
-
-                for (size_t b = 0; b < beam_size; b++) {
-                    // Initialize it. Compute a sum of first 8 AQ codes
-                    // because m >= 8 .
-                    accum_and_store_tab<8, 4>(
-                            m,
-                            codebook_cross_norms,
-                            codebook_offsets,
-                            codes_i,
-                            b,
-                            ldc,
-                            K,
-                            dp.data());
-
-#define ACCUM_AND_ADD_TAB(NK)             \
-    case NK:                              \
-        accum_and_add_tab<NK, 4>(         \
-                m,                        \
-                codebook_cross_norms,     \
-                codebook_offsets + im,    \
-                codes_i + im,             \
-                b,                        \
-                ldc,                      \
-                K,                        \
-                dp.data());               \
-        break;
-
-                    // accumulate up to 8 additional AQ codes into
-                    // a temporary buffer
-                    for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
-                        size_t m_left = m - im;
-                        if (m_left > 8) {
-                            m_left = 8;
-                        }
-
-                        switch (m_left) {
-                            ACCUM_AND_ADD_TAB(1)
-                            ACCUM_AND_ADD_TAB(2)
-                            ACCUM_AND_ADD_TAB(3)
-                            ACCUM_AND_ADD_TAB(4)
-                            ACCUM_AND_ADD_TAB(5)
-                            ACCUM_AND_ADD_TAB(6)
-                            ACCUM_AND_ADD_TAB(7)
-                            ACCUM_AND_ADD_TAB(8)
-                        }
-                    }
-
-                    // done. finalize the result
-                    for (size_t k = 0; k < K; k++) {
-                        cent_distances[b * K + k] =
-                                distances_i[b] + cd_common[k] + 2 * dp[k];
-                    }
-                }
-            }
-        }
-
-        // the optimized implementation ends here
-
-        using C = CMax<float, int>;
-        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
-        float* new_distances_i = new_distances + i * new_beam_size;
-
-        const float* cent_distances_i = cent_distances.data();
-
-        // then we have to select the best results
-        for (int i = 0; i < new_beam_size; i++) {
-            new_distances_i[i] = C::neutral();
-        }
-        std::vector<int> perm(new_beam_size, -1);
-
-#define HANDLE_APPROX(NB, BD)                                  \
-    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
-        HeapWithBuckets<C, NB, BD>::bs_addn(                   \
-                beam_size,                                     \
-                K,                                             \
-                cent_distances_i,                              \
-                new_beam_size,                                 \
-                new_distances_i,                               \
-                perm.data());                                  \
-        break;
-
-        switch (approx_topk_mode) {
-            HANDLE_APPROX(8, 3)
-            HANDLE_APPROX(8, 2)
-            HANDLE_APPROX(16, 2)
-            HANDLE_APPROX(32, 2)
-            default:
-                heap_addn<C>(
-                        new_beam_size,
-                        new_distances_i,
-                        perm.data(),
-                        cent_distances_i,
-                        nullptr,
-                        beam_size * K);
-                break;
-        }
-
-        heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
-
-#undef HANDLE_APPROX
-
-        for (int j = 0; j < new_beam_size; j++) {
-            int js = perm[j] / K;
-            int ls = perm[j] % K;
-            if (m > 0) {
-                memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
-            }
-            new_codes_i[m] = ls;
-            new_codes_i += m + 1;
-        }
-    }
-}
-
-//
-void refine_beam_LUT_mp(
-        const ResidualQuantizer& rq,
-        size_t n,
-        const float* query_norms, // size n
-        const float* query_cp,    //
-        int out_beam_size,
-        int32_t* out_codes,
-        float* out_distances,
-        RefineBeamLUTMemoryPool& pool) {
-    int beam_size = 1;
-
-    double t0 = getmillisecs();
-
-    // find the max_beam_size
-    int max_beam_size = 0;
-    {
-        int tmp_beam_size = beam_size;
-        for (int m = 0; m < rq.M; m++) {
-            int K = 1 << rq.nbits[m];
-            int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
-            tmp_beam_size = new_beam_size;
-
-            if (max_beam_size < new_beam_size) {
-                max_beam_size = new_beam_size;
-            }
-        }
-    }
-
-    // preallocate buffers
-    pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
-    pool.new_distances.resize(n * max_beam_size);
-
-    pool.codes.resize(n * max_beam_size * (rq.M + 1));
-    pool.distances.resize(n * max_beam_size);
-
-    for (size_t i = 0; i < n; i++) {
-        pool.distances[i] = query_norms[i];
-    }
-
-    // set up pointers to buffers
-    int32_t* __restrict new_codes_ptr = pool.new_codes.data();
-    float* __restrict new_distances_ptr = pool.new_distances.data();
-
-    int32_t* __restrict codes_ptr = pool.codes.data();
-    float* __restrict distances_ptr = pool.distances.data();
-
-    // main loop
-    size_t codes_size = 0;
-    size_t distances_size = 0;
-    for (int m = 0; m < rq.M; m++) {
-        int K = 1 << rq.nbits[m];
-
-        // it is guaranteed that (new_beam_size <= than max_beam_size) ==
-        // true
-        int new_beam_size = std::min(beam_size * K, out_beam_size);
-
-        // std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
-        // std::vector<float> new_distances(n * new_beam_size);
-
-        codes_size = n * new_beam_size * (m + 1);
-        distances_size = n * new_beam_size;
-
-        beam_search_encode_step_tab(
-                K,
-                n,
-                beam_size,
-                rq.codebook_cross_products.data() + rq.codebook_offsets[m],
-                rq.total_codebook_size,
-                rq.codebook_offsets.data(),
-                query_cp + rq.codebook_offsets[m],
-                rq.total_codebook_size,
-                rq.cent_norms.data() + rq.codebook_offsets[m],
-                m,
-                // codes.data(),
-                codes_ptr,
-                // distances.data(),
-                distances_ptr,
-                new_beam_size,
-                // new_codes.data(),
-                new_codes_ptr,
-                // new_distances.data()
-                new_distances_ptr,
-                rq.approx_topk_mode);
-
-        // codes.swap(new_codes);
-        std::swap(codes_ptr, new_codes_ptr);
-        // distances.swap(new_distances);
-        std::swap(distances_ptr, new_distances_ptr);
-
-        beam_size = new_beam_size;
-
-        if (rq.verbose) {
-            float sum_distances = 0;
-            // for (int j = 0; j < distances.size(); j++) {
-            //     sum_distances += distances[j];
-            // }
-            for (int j = 0; j < distances_size; j++) {
-                sum_distances += distances_ptr[j];
-            }
-            printf("[%.3f s] encode stage %d, %d bits, "
-                   "total error %g, beam_size %d\n",
-                   (getmillisecs() - t0) / 1000,
-                   m,
-                   int(rq.nbits[m]),
-                   sum_distances,
-                   beam_size);
-        }
-    }
-
-    if (out_codes) {
-        // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
-        memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
-    }
-    if (out_distances) {
-        // memcpy(out_distances,
-        //        distances.data(),
-        //        distances.size() * sizeof(distances[0]));
-        memcpy(out_distances,
-               distances_ptr,
-               distances_size * sizeof(*distances_ptr));
+               codebook_cross_products.data() + ofs,
+               &ki);
+        ofs += ki * kk;
     }
 }
 
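The rewritten compute_codebook_tables() above stops materializing the full total_codebook_size x total_codebook_size Gram matrix: it now packs, for each codebook m >= 1, only the K_m x codebook_offsets[m] block of dot products between codebook m and all earlier codebook entries, stored column-major at offset ofs with leading dimension K_m (and cent_norms now comes from fvec_norms_L2sqr instead of the Gram diagonal). A plain-loop sketch of that packed layout, where build_cross_products is a hypothetical stand-in for the sgemm_-based code shown in the diff:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration only: same table as the new compute_codebook_tables(), built
// with naive loops instead of sgemm_. `codebooks` holds total_codebook_size
// centroids of dimension d, stored contiguously; codebook_offsets[m] is the
// prefix sum of the codebook sizes (1 << nbits[m]).
std::vector<float> build_cross_products(
        size_t d,
        const std::vector<size_t>& nbits,
        const std::vector<uint64_t>& codebook_offsets,
        const std::vector<float>& codebooks) {
    size_t M = nbits.size();
    // for each codebook m: K_m rows, one column per earlier centroid
    size_t cross_table_size = 0;
    for (size_t m = 0; m < M; m++) {
        cross_table_size += (size_t(1) << nbits[m]) * codebook_offsets[m];
    }
    std::vector<float> cross(cross_table_size);
    size_t ofs = 0;
    for (size_t m = 1; m < M; m++) {
        size_t K = size_t(1) << nbits[m];   // entries in codebook m
        size_t prev = codebook_offsets[m];  // centroids in codebooks 0..m-1
        // column-major K x prev block, matching sgemm_(..., ldc = K)
        for (size_t j = 0; j < prev; j++) {
            for (size_t i = 0; i < K; i++) {
                const float* ci = codebooks.data() + (prev + i) * d;
                const float* cj = codebooks.data() + j * d;
                float dp = 0;
                for (size_t l = 0; l < d; l++) {
                    dp += ci[l] * cj[l];
                }
                cross[ofs + i + j * K] = dp;
            }
        }
        ofs += K * prev;
    }
    return cross;
}

For example, with nbits = {4, 4, 4} the packed table holds 16*16 + 16*32 = 768 floats instead of the 48*48 = 2304 of the old square layout, and the saving grows with the number of codebooks since only cross-blocks against earlier codebooks are kept.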