faiss 0.2.7 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (172)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +10 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/lib/faiss.rb +1 -1
  11. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  12. data/vendor/faiss/faiss/AutoTune.h +0 -1
  13. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  14. data/vendor/faiss/faiss/Clustering.h +31 -21
  15. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  16. data/vendor/faiss/faiss/Index.cpp +1 -1
  17. data/vendor/faiss/faiss/Index.h +20 -5
  18. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  20. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  21. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  22. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  23. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  34. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  38. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  59. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  60. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  61. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  62. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  63. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  64. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  65. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  66. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  67. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  69. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  70. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  71. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  72. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  73. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  74. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  75. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  76. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  77. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  78. data/vendor/faiss/faiss/clone_index.h +3 -0
  79. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  80. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  81. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  82. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  83. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  88. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  90. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  92. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  93. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  97. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  98. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  99. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  101. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  103. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  104. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  105. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  106. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  107. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  108. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  109. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  110. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  111. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  112. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  113. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  115. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  118. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  119. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  121. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  124. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  125. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  126. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  127. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  128. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  129. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  133. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  135. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  136. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  137. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  138. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  139. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  140. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  141. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  142. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  143. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  144. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  145. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  146. data/vendor/faiss/faiss/utils/distances.h +81 -4
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  148. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  150. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  152. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  153. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  154. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  155. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  156. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  157. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  158. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  159. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  160. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  161. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  162. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  163. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  164. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  165. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  166. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  167. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  168. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  169. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  170. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  171. data/vendor/faiss/faiss/utils/utils.h +57 -20
  172. metadata +11 -4
@@ -16,17 +16,12 @@
16
16
 
17
17
  #include <faiss/IndexFlat.h>
18
18
  #include <faiss/VectorTransform.h>
19
- #include <faiss/impl/AuxIndexStructures.h>
20
19
  #include <faiss/impl/FaissAssert.h>
21
- #include <faiss/utils/Heap.h>
20
+ #include <faiss/impl/residual_quantizer_encode_steps.h>
22
21
  #include <faiss/utils/distances.h>
23
22
  #include <faiss/utils/hamming.h>
24
23
  #include <faiss/utils/utils.h>
25
24
 
26
- #include <faiss/utils/simdlib.h>
27
-
28
- #include <faiss/utils/approx_topk/approx_topk.h>
29
-
30
25
  extern "C" {
31
26
 
32
27
  // general matrix multiplication
@@ -125,157 +120,9 @@ void ResidualQuantizer::initialize_from(
125
120
  }
126
121
  }
127
122
 
128
- void beam_search_encode_step(
129
- size_t d,
130
- size_t K,
131
- const float* cent, /// size (K, d)
132
- size_t n,
133
- size_t beam_size,
134
- const float* residuals, /// size (n, beam_size, d)
135
- size_t m,
136
- const int32_t* codes, /// size (n, beam_size, m)
137
- size_t new_beam_size,
138
- int32_t* new_codes, /// size (n, new_beam_size, m + 1)
139
- float* new_residuals, /// size (n, new_beam_size, d)
140
- float* new_distances, /// size (n, new_beam_size)
141
- Index* assign_index,
142
- ApproxTopK_mode_t approx_topk_mode) {
143
- // we have to fill in the whole output matrix
144
- FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
145
-
146
- std::vector<float> cent_distances;
147
- std::vector<idx_t> cent_ids;
148
-
149
- if (assign_index) {
150
- // search beam_size distances per query
151
- FAISS_THROW_IF_NOT(assign_index->d == d);
152
- cent_distances.resize(n * beam_size * new_beam_size);
153
- cent_ids.resize(n * beam_size * new_beam_size);
154
- if (assign_index->ntotal != 0) {
155
- // then we assume the codebooks are already added to the index
156
- FAISS_THROW_IF_NOT(assign_index->ntotal == K);
157
- } else {
158
- assign_index->add(K, cent);
159
- }
160
-
161
- // printf("beam_search_encode_step -- mem usage %zd\n",
162
- // get_mem_usage_kb());
163
- assign_index->search(
164
- n * beam_size,
165
- residuals,
166
- new_beam_size,
167
- cent_distances.data(),
168
- cent_ids.data());
169
- } else {
170
- // do one big distance computation
171
- cent_distances.resize(n * beam_size * K);
172
- pairwise_L2sqr(
173
- d, n * beam_size, residuals, K, cent, cent_distances.data());
174
- }
175
- InterruptCallback::check();
176
-
177
- #pragma omp parallel for if (n > 100)
178
- for (int64_t i = 0; i < n; i++) {
179
- const int32_t* codes_i = codes + i * m * beam_size;
180
- int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
181
- const float* residuals_i = residuals + i * d * beam_size;
182
- float* new_residuals_i = new_residuals + i * d * new_beam_size;
183
-
184
- float* new_distances_i = new_distances + i * new_beam_size;
185
- using C = CMax<float, int>;
186
-
187
- if (assign_index) {
188
- const float* cent_distances_i =
189
- cent_distances.data() + i * beam_size * new_beam_size;
190
- const idx_t* cent_ids_i =
191
- cent_ids.data() + i * beam_size * new_beam_size;
192
-
193
- // here we could be a tad more efficient by merging sorted arrays
194
- for (int i = 0; i < new_beam_size; i++) {
195
- new_distances_i[i] = C::neutral();
196
- }
197
- std::vector<int> perm(new_beam_size, -1);
198
- heap_addn<C>(
199
- new_beam_size,
200
- new_distances_i,
201
- perm.data(),
202
- cent_distances_i,
203
- nullptr,
204
- beam_size * new_beam_size);
205
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
206
-
207
- for (int j = 0; j < new_beam_size; j++) {
208
- int js = perm[j] / new_beam_size;
209
- int ls = cent_ids_i[perm[j]];
210
- if (m > 0) {
211
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
212
- }
213
- new_codes_i[m] = ls;
214
- new_codes_i += m + 1;
215
- fvec_sub(
216
- d,
217
- residuals_i + js * d,
218
- cent + ls * d,
219
- new_residuals_i);
220
- new_residuals_i += d;
221
- }
222
-
223
- } else {
224
- const float* cent_distances_i =
225
- cent_distances.data() + i * beam_size * K;
226
- // then we have to select the best results
227
- for (int i = 0; i < new_beam_size; i++) {
228
- new_distances_i[i] = C::neutral();
229
- }
230
- std::vector<int> perm(new_beam_size, -1);
231
-
232
- #define HANDLE_APPROX(NB, BD) \
233
- case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
234
- HeapWithBuckets<C, NB, BD>::bs_addn( \
235
- beam_size, \
236
- K, \
237
- cent_distances_i, \
238
- new_beam_size, \
239
- new_distances_i, \
240
- perm.data()); \
241
- break;
242
-
243
- switch (approx_topk_mode) {
244
- HANDLE_APPROX(8, 3)
245
- HANDLE_APPROX(8, 2)
246
- HANDLE_APPROX(16, 2)
247
- HANDLE_APPROX(32, 2)
248
- default:
249
- heap_addn<C>(
250
- new_beam_size,
251
- new_distances_i,
252
- perm.data(),
253
- cent_distances_i,
254
- nullptr,
255
- beam_size * K);
256
- }
257
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
258
-
259
- #undef HANDLE_APPROX
260
-
261
- for (int j = 0; j < new_beam_size; j++) {
262
- int js = perm[j] / K;
263
- int ls = perm[j] % K;
264
- if (m > 0) {
265
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
266
- }
267
- new_codes_i[m] = ls;
268
- new_codes_i += m + 1;
269
- fvec_sub(
270
- d,
271
- residuals_i + js * d,
272
- cent + ls * d,
273
- new_residuals_i);
274
- new_residuals_i += d;
275
- }
276
- }
277
- }
278
- }
123
+ /****************************************************************
124
+ * Training
125
+ ****************************************************************/
279
126
 
280
127
  void ResidualQuantizer::train(size_t n, const float* x) {
281
128
  codebooks.resize(d * codebook_offsets.back());
@@ -568,180 +415,11 @@ size_t ResidualQuantizer::memory_per_point(int beam_size) const {
568
415
  return mem;
569
416
  }
570
417
 
571
- // a namespace full of preallocated buffers
572
- namespace {
573
-
574
- // Preallocated memory chunk for refine_beam_mp() call
575
- struct RefineBeamMemoryPool {
576
- std::vector<int32_t> new_codes;
577
- std::vector<float> new_residuals;
578
-
579
- std::vector<float> residuals;
580
- std::vector<int32_t> codes;
581
- std::vector<float> distances;
582
- };
583
-
584
- // Preallocated memory chunk for refine_beam_LUT_mp() call
585
- struct RefineBeamLUTMemoryPool {
586
- std::vector<int32_t> new_codes;
587
- std::vector<float> new_distances;
588
-
589
- std::vector<int32_t> codes;
590
- std::vector<float> distances;
591
- };
592
-
593
- // this is for use_beam_LUT == 0 in compute_codes_add_centroids_mp_lut0() call
594
- struct ComputeCodesAddCentroidsLUT0MemoryPool {
595
- std::vector<int32_t> codes;
596
- std::vector<float> norms;
597
- std::vector<float> distances;
598
- std::vector<float> residuals;
599
- RefineBeamMemoryPool refine_beam_pool;
600
- };
601
-
602
- // this is for use_beam_LUT == 1 in compute_codes_add_centroids_mp_lut1() call
603
- struct ComputeCodesAddCentroidsLUT1MemoryPool {
604
- std::vector<int32_t> codes;
605
- std::vector<float> distances;
606
- std::vector<float> query_norms;
607
- std::vector<float> query_cp;
608
- std::vector<float> residuals;
609
- RefineBeamLUTMemoryPool refine_beam_lut_pool;
610
- };
611
-
612
- } // namespace
613
-
614
- // forward declaration
615
- void refine_beam_mp(
616
- const ResidualQuantizer& rq,
617
- size_t n,
618
- size_t beam_size,
619
- const float* x,
620
- int out_beam_size,
621
- int32_t* out_codes,
622
- float* out_residuals,
623
- float* out_distances,
624
- RefineBeamMemoryPool& pool);
625
-
626
- // forward declaration
627
- void refine_beam_LUT_mp(
628
- const ResidualQuantizer& rq,
629
- size_t n,
630
- const float* query_norms, // size n
631
- const float* query_cp, //
632
- int out_beam_size,
633
- int32_t* out_codes,
634
- float* out_distances,
635
- RefineBeamLUTMemoryPool& pool);
636
-
637
- // this is for use_beam_LUT == 0
638
- void compute_codes_add_centroids_mp_lut0(
639
- const ResidualQuantizer& rq,
640
- const float* x,
641
- uint8_t* codes_out,
642
- size_t n,
643
- const float* centroids,
644
- ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
645
- pool.codes.resize(rq.max_beam_size * rq.M * n);
646
- pool.distances.resize(rq.max_beam_size * n);
418
+ /****************************************************************
419
+ * Encoding
420
+ ****************************************************************/
647
421
 
648
- pool.residuals.resize(rq.max_beam_size * n * rq.d);
649
-
650
- refine_beam_mp(
651
- rq,
652
- n,
653
- 1,
654
- x,
655
- rq.max_beam_size,
656
- pool.codes.data(),
657
- pool.residuals.data(),
658
- pool.distances.data(),
659
- pool.refine_beam_pool);
660
-
661
- if (rq.search_type == ResidualQuantizer::ST_norm_float ||
662
- rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
663
- rq.search_type == ResidualQuantizer::ST_norm_qint4) {
664
- pool.norms.resize(n);
665
- // recover the norms of reconstruction as
666
- // || original_vector - residual ||^2
667
- for (size_t i = 0; i < n; i++) {
668
- pool.norms[i] = fvec_L2sqr(
669
- x + i * rq.d,
670
- pool.residuals.data() + i * rq.max_beam_size * rq.d,
671
- rq.d);
672
- }
673
- }
674
-
675
- // pack only the first code of the beam
676
- // (hence the ld_codes=M * max_beam_size)
677
- rq.pack_codes(
678
- n,
679
- pool.codes.data(),
680
- codes_out,
681
- rq.M * rq.max_beam_size,
682
- (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
683
- centroids);
684
- }
685
-
686
- // use_beam_LUT == 1
687
- void compute_codes_add_centroids_mp_lut1(
688
- const ResidualQuantizer& rq,
689
- const float* x,
690
- uint8_t* codes_out,
691
- size_t n,
692
- const float* centroids,
693
- ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
694
- //
695
- pool.codes.resize(rq.max_beam_size * rq.M * n);
696
- pool.distances.resize(rq.max_beam_size * n);
697
-
698
- FAISS_THROW_IF_NOT_MSG(
699
- rq.codebook_cross_products.size() ==
700
- rq.total_codebook_size * rq.total_codebook_size,
701
- "call compute_codebook_tables first");
702
-
703
- pool.query_norms.resize(n);
704
- fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
705
-
706
- pool.query_cp.resize(n * rq.total_codebook_size);
707
- {
708
- FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
709
- float zero = 0, one = 1;
710
- sgemm_("Transposed",
711
- "Not transposed",
712
- &ti,
713
- &ni,
714
- &di,
715
- &one,
716
- rq.codebooks.data(),
717
- &di,
718
- x,
719
- &di,
720
- &zero,
721
- pool.query_cp.data(),
722
- &ti);
723
- }
724
-
725
- refine_beam_LUT_mp(
726
- rq,
727
- n,
728
- pool.query_norms.data(),
729
- pool.query_cp.data(),
730
- rq.max_beam_size,
731
- pool.codes.data(),
732
- pool.distances.data(),
733
- pool.refine_beam_lut_pool);
734
-
735
- // pack only the first code of the beam
736
- // (hence the ld_codes=M * max_beam_size)
737
- rq.pack_codes(
738
- n,
739
- pool.codes.data(),
740
- codes_out,
741
- rq.M * rq.max_beam_size,
742
- nullptr,
743
- centroids);
744
- }
422
+ using namespace rq_encode_steps;
745
423
 
746
424
  void ResidualQuantizer::compute_codes_add_centroids(
747
425
  const float* x,
@@ -769,11 +447,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
769
447
  cent = centroids + i0 * d;
770
448
  }
771
449
 
772
- // compute_codes_add_centroids(
773
- // x + i0 * d,
774
- // codes_out + i0 * code_size,
775
- // i1 - i0,
776
- // cent);
777
450
  if (use_beam_LUT == 0) {
778
451
  compute_codes_add_centroids_mp_lut0(
779
452
  *this,
@@ -794,147 +467,6 @@ void ResidualQuantizer::compute_codes_add_centroids(
794
467
  }
795
468
  }
796
469
 
797
- void refine_beam_mp(
798
- const ResidualQuantizer& rq,
799
- size_t n,
800
- size_t beam_size,
801
- const float* x,
802
- int out_beam_size,
803
- int32_t* out_codes,
804
- float* out_residuals,
805
- float* out_distances,
806
- RefineBeamMemoryPool& pool) {
807
- int cur_beam_size = beam_size;
808
-
809
- double t0 = getmillisecs();
810
-
811
- // find the max_beam_size
812
- int max_beam_size = 0;
813
- {
814
- int tmp_beam_size = cur_beam_size;
815
- for (int m = 0; m < rq.M; m++) {
816
- int K = 1 << rq.nbits[m];
817
- int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
818
- tmp_beam_size = new_beam_size;
819
-
820
- if (max_beam_size < new_beam_size) {
821
- max_beam_size = new_beam_size;
822
- }
823
- }
824
- }
825
-
826
- // preallocate buffers
827
- pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
828
- pool.new_residuals.resize(n * max_beam_size * rq.d);
829
-
830
- pool.codes.resize(n * max_beam_size * (rq.M + 1));
831
- pool.distances.resize(n * max_beam_size);
832
- pool.residuals.resize(n * rq.d * max_beam_size);
833
-
834
- for (size_t i = 0; i < n * rq.d * beam_size; i++) {
835
- pool.residuals[i] = x[i];
836
- }
837
-
838
- // set up pointers to buffers
839
- int32_t* __restrict codes_ptr = pool.codes.data();
840
- float* __restrict residuals_ptr = pool.residuals.data();
841
-
842
- int32_t* __restrict new_codes_ptr = pool.new_codes.data();
843
- float* __restrict new_residuals_ptr = pool.new_residuals.data();
844
-
845
- // index
846
- std::unique_ptr<Index> assign_index;
847
- if (rq.assign_index_factory) {
848
- assign_index.reset((*rq.assign_index_factory)(rq.d));
849
- } else {
850
- assign_index.reset(new IndexFlatL2(rq.d));
851
- }
852
-
853
- // main loop
854
- size_t codes_size = 0;
855
- size_t distances_size = 0;
856
- size_t residuals_size = 0;
857
-
858
- for (int m = 0; m < rq.M; m++) {
859
- int K = 1 << rq.nbits[m];
860
-
861
- const float* __restrict codebooks_m =
862
- rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
863
-
864
- const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
865
-
866
- codes_size = n * new_beam_size * (m + 1);
867
- residuals_size = n * new_beam_size * rq.d;
868
- distances_size = n * new_beam_size;
869
-
870
- beam_search_encode_step(
871
- rq.d,
872
- K,
873
- codebooks_m,
874
- n,
875
- cur_beam_size,
876
- // residuals.data(),
877
- residuals_ptr,
878
- m,
879
- // codes.data(),
880
- codes_ptr,
881
- new_beam_size,
882
- // new_codes.data(),
883
- new_codes_ptr,
884
- // new_residuals.data(),
885
- new_residuals_ptr,
886
- pool.distances.data(),
887
- assign_index.get(),
888
- rq.approx_topk_mode);
889
-
890
- assign_index->reset();
891
-
892
- std::swap(codes_ptr, new_codes_ptr);
893
- std::swap(residuals_ptr, new_residuals_ptr);
894
-
895
- cur_beam_size = new_beam_size;
896
-
897
- if (rq.verbose) {
898
- float sum_distances = 0;
899
- // for (int j = 0; j < distances.size(); j++) {
900
- // sum_distances += distances[j];
901
- // }
902
- for (int j = 0; j < distances_size; j++) {
903
- sum_distances += pool.distances[j];
904
- }
905
-
906
- printf("[%.3f s] encode stage %d, %d bits, "
907
- "total error %g, beam_size %d\n",
908
- (getmillisecs() - t0) / 1000,
909
- m,
910
- int(rq.nbits[m]),
911
- sum_distances,
912
- cur_beam_size);
913
- }
914
- }
915
-
916
- if (out_codes) {
917
- // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
918
- memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
919
- }
920
- if (out_residuals) {
921
- // memcpy(out_residuals,
922
- // residuals.data(),
923
- // residuals.size() * sizeof(residuals[0]));
924
- memcpy(out_residuals,
925
- residuals_ptr,
926
- residuals_size * sizeof(*residuals_ptr));
927
- }
928
- if (out_distances) {
929
- // memcpy(out_distances,
930
- // distances.data(),
931
- // distances.size() * sizeof(distances[0]));
932
- memcpy(out_distances,
933
- pool.distances.data(),
934
- distances_size * sizeof(pool.distances[0]));
935
- }
936
- }
937
-
938
470
  void ResidualQuantizer::refine_beam(
939
471
  size_t n,
940
472
  size_t beam_size,
@@ -961,556 +493,36 @@ void ResidualQuantizer::refine_beam(
961
493
  *******************************************************************/
962
494
 
963
495
  void ResidualQuantizer::compute_codebook_tables() {
964
- codebook_cross_products.resize(total_codebook_size * total_codebook_size);
965
496
  cent_norms.resize(total_codebook_size);
966
- // stricly speaking we could use ssyrk
967
- {
968
- FINTEGER ni = total_codebook_size;
497
+ fvec_norms_L2sqr(
498
+ cent_norms.data(), codebooks.data(), d, total_codebook_size);
499
+ size_t cross_table_size = 0;
500
+ for (int m = 0; m < M; m++) {
501
+ size_t K = (size_t)1 << nbits[m];
502
+ cross_table_size += K * codebook_offsets[m];
503
+ }
504
+ codebook_cross_products.resize(cross_table_size);
505
+ size_t ofs = 0;
506
+ for (int m = 1; m < M; m++) {
507
+ FINTEGER ki = (size_t)1 << nbits[m];
508
+ FINTEGER kk = codebook_offsets[m];
969
509
  FINTEGER di = d;
970
510
  float zero = 0, one = 1;
511
+ assert(ofs + ki * kk <= cross_table_size);
971
512
  sgemm_("Transposed",
972
513
  "Not transposed",
973
- &ni,
974
- &ni,
514
+ &ki,
515
+ &kk,
975
516
  &di,
976
517
  &one,
977
- codebooks.data(),
518
+ codebooks.data() + d * kk,
978
519
  &di,
979
520
  codebooks.data(),
980
521
  &di,
981
522
  &zero,
982
- codebook_cross_products.data(),
983
- &ni);
984
- }
985
- for (size_t i = 0; i < total_codebook_size; i++) {
986
- cent_norms[i] = codebook_cross_products[i + i * total_codebook_size];
987
- }
988
- }
989
-
990
- namespace {
991
-
992
- template <size_t M, size_t NK>
993
- void accum_and_store_tab(
994
- const size_t m_offset,
995
- const float* const __restrict codebook_cross_norms,
996
- const uint64_t* const __restrict codebook_offsets,
997
- const int32_t* const __restrict codes_i,
998
- const size_t b,
999
- const size_t ldc,
1000
- const size_t K,
1001
- float* const __restrict output) {
1002
- // load pointers into registers
1003
- const float* cbs[M];
1004
- for (size_t ij = 0; ij < M; ij++) {
1005
- const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
1006
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1007
- }
1008
-
1009
- // do accumulation in registers using SIMD.
1010
- // It is possible that compiler may be smart enough so that
1011
- // this manual SIMD unrolling might be unneeded.
1012
- #if defined(__AVX2__) || defined(__aarch64__)
1013
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
1014
-
1015
- // process in chunks of size (8 * NK) floats
1016
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1017
- simd8float32 regs[NK];
1018
- for (size_t ik = 0; ik < NK; ik++) {
1019
- regs[ik].loadu(cbs[0] + kk + ik * 8);
1020
- }
1021
-
1022
- for (size_t ij = 1; ij < M; ij++) {
1023
- for (size_t ik = 0; ik < NK; ik++) {
1024
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1025
- }
1026
- }
1027
-
1028
- // write the result
1029
- for (size_t ik = 0; ik < NK; ik++) {
1030
- regs[ik].storeu(output + kk + ik * 8);
1031
- }
1032
- }
1033
- #else
1034
- const size_t K8 = 0;
1035
- #endif
1036
-
1037
- // process leftovers
1038
- for (size_t kk = K8; kk < K; kk++) {
1039
- float reg = cbs[0][kk];
1040
- for (size_t ij = 1; ij < M; ij++) {
1041
- reg += cbs[ij][kk];
1042
- }
1043
- output[b * K + kk] = reg;
1044
- }
1045
- }
1046
-
1047
- template <size_t M, size_t NK>
1048
- void accum_and_add_tab(
1049
- const size_t m_offset,
1050
- const float* const __restrict codebook_cross_norms,
1051
- const uint64_t* const __restrict codebook_offsets,
1052
- const int32_t* const __restrict codes_i,
1053
- const size_t b,
1054
- const size_t ldc,
1055
- const size_t K,
1056
- float* const __restrict output) {
1057
- // load pointers into registers
1058
- const float* cbs[M];
1059
- for (size_t ij = 0; ij < M; ij++) {
1060
- const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
1061
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1062
- }
1063
-
1064
- // do accumulation in registers using SIMD.
1065
- // It is possible that compiler may be smart enough so that
1066
- // this manual SIMD unrolling might be unneeded.
1067
- #if defined(__AVX2__) || defined(__aarch64__)
1068
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
1069
-
1070
- // process in chunks of size (8 * NK) floats
1071
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1072
- simd8float32 regs[NK];
1073
- for (size_t ik = 0; ik < NK; ik++) {
1074
- regs[ik].loadu(cbs[0] + kk + ik * 8);
1075
- }
1076
-
1077
- for (size_t ij = 1; ij < M; ij++) {
1078
- for (size_t ik = 0; ik < NK; ik++) {
1079
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1080
- }
1081
- }
1082
-
1083
- // write the result
1084
- for (size_t ik = 0; ik < NK; ik++) {
1085
- simd8float32 existing(output + kk + ik * 8);
1086
- existing += regs[ik];
1087
- existing.storeu(output + kk + ik * 8);
1088
- }
1089
- }
1090
- #else
1091
- const size_t K8 = 0;
1092
- #endif
1093
-
1094
- // process leftovers
1095
- for (size_t kk = K8; kk < K; kk++) {
1096
- float reg = cbs[0][kk];
1097
- for (size_t ij = 1; ij < M; ij++) {
1098
- reg += cbs[ij][kk];
1099
- }
1100
- output[b * K + kk] += reg;
1101
- }
1102
- }
1103
-
1104
- template <size_t M, size_t NK>
1105
- void accum_and_finalize_tab(
1106
- const float* const __restrict codebook_cross_norms,
1107
- const uint64_t* const __restrict codebook_offsets,
1108
- const int32_t* const __restrict codes_i,
1109
- const size_t b,
1110
- const size_t ldc,
1111
- const size_t K,
1112
- const float* const __restrict distances_i,
1113
- const float* const __restrict cd_common,
1114
- float* const __restrict output) {
1115
- // load pointers into registers
1116
- const float* cbs[M];
1117
- for (size_t ij = 0; ij < M; ij++) {
1118
- const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
1119
- cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
1120
- }
1121
-
1122
- // do accumulation in registers using SIMD.
1123
- // It is possible that compiler may be smart enough so that
1124
- // this manual SIMD unrolling might be unneeded.
1125
- #if defined(__AVX2__) || defined(__aarch64__)
1126
- const size_t K8 = (K / (8 * NK)) * (8 * NK);
1127
-
1128
- // process in chunks of size (8 * NK) floats
1129
- for (size_t kk = 0; kk < K8; kk += 8 * NK) {
1130
- simd8float32 regs[NK];
1131
- for (size_t ik = 0; ik < NK; ik++) {
1132
- regs[ik].loadu(cbs[0] + kk + ik * 8);
1133
- }
1134
-
1135
- for (size_t ij = 1; ij < M; ij++) {
1136
- for (size_t ik = 0; ik < NK; ik++) {
1137
- regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
1138
- }
1139
- }
1140
-
1141
- simd8float32 two(2.0f);
1142
- for (size_t ik = 0; ik < NK; ik++) {
1143
- // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
1144
- // + 2 * dp[k];
1145
-
1146
- simd8float32 common_v(cd_common + kk + ik * 8);
1147
- common_v = fmadd(two, regs[ik], common_v);
1148
-
1149
- common_v += simd8float32(distances_i[b]);
1150
- common_v.storeu(output + b * K + kk + ik * 8);
1151
- }
1152
- }
1153
- #else
1154
- const size_t K8 = 0;
1155
- #endif
1156
-
1157
- // process leftovers
1158
- for (size_t kk = K8; kk < K; kk++) {
1159
- float reg = cbs[0][kk];
1160
- for (size_t ij = 1; ij < M; ij++) {
1161
- reg += cbs[ij][kk];
1162
- }
1163
-
1164
- output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
1165
- }
1166
- }
1167
-
1168
- } // namespace
1169
-
1170
- void beam_search_encode_step_tab(
1171
- size_t K,
1172
- size_t n,
1173
- size_t beam_size, // input sizes
1174
- const float* codebook_cross_norms, // size K * ldc
1175
- size_t ldc, // >= K
1176
- const uint64_t* codebook_offsets, // m
1177
- const float* query_cp, // size n * ldqc
1178
- size_t ldqc, // >= K
1179
- const float* cent_norms_i, // size K
1180
- size_t m,
1181
- const int32_t* codes, // n * beam_size * m
1182
- const float* distances, // n * beam_size
1183
- size_t new_beam_size,
1184
- int32_t* new_codes, // n * new_beam_size * (m + 1)
1185
- float* new_distances, // n * new_beam_size
1186
- ApproxTopK_mode_t approx_topk_mode) //
1187
- {
1188
- FAISS_THROW_IF_NOT(ldc >= K);
1189
-
1190
- #pragma omp parallel for if (n > 100) schedule(dynamic)
1191
- for (int64_t i = 0; i < n; i++) {
1192
- std::vector<float> cent_distances(beam_size * K);
1193
- std::vector<float> cd_common(K);
1194
-
1195
- const int32_t* codes_i = codes + i * m * beam_size;
1196
- const float* query_cp_i = query_cp + i * ldqc;
1197
- const float* distances_i = distances + i * beam_size;
1198
-
1199
- for (size_t k = 0; k < K; k++) {
1200
- cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
1201
- }
1202
-
1203
- /*
1204
- // This is the baseline implementation. Its primary flaw
1205
- // that it writes way too many info to the temporary buffer
1206
- // called dp.
1207
- //
1208
- // This baseline code is kept intentionally because it is easy to
1209
- // understand what an optimized version optimizes exactly.
1210
- //
1211
- for (size_t b = 0; b < beam_size; b++) {
1212
- std::vector<float> dp(K);
1213
-
1214
- for (size_t m1 = 0; m1 < m; m1++) {
1215
- size_t c = codes_i[b * m + m1];
1216
- const float* cb =
1217
- &codebook_cross_norms[(codebook_offsets[m1] + c) * ldc];
1218
- fvec_add(K, cb, dp.data(), dp.data());
1219
- }
1220
-
1221
- for (size_t k = 0; k < K; k++) {
1222
- cent_distances[b * K + k] =
1223
- distances_i[b] + cd_common[k] + 2 * dp[k];
1224
- }
1225
- }
1226
- */
1227
-
1228
- // An optimized implementation that avoids using a temporary buffer
1229
- // and does the accumulation in registers.
1230
-
1231
- // Compute a sum of NK AQ codes.
1232
- #define ACCUM_AND_FINALIZE_TAB(NK) \
1233
- case NK: \
1234
- for (size_t b = 0; b < beam_size; b++) { \
1235
- accum_and_finalize_tab<NK, 4>( \
1236
- codebook_cross_norms, \
1237
- codebook_offsets, \
1238
- codes_i, \
1239
- b, \
1240
- ldc, \
1241
- K, \
1242
- distances_i, \
1243
- cd_common.data(), \
1244
- cent_distances.data()); \
1245
- } \
1246
- break;
1247
-
1248
- // this version contains many switch-case scenarios, but
1249
- // they won't affect branch predictor.
1250
- switch (m) {
1251
- case 0:
1252
- // trivial case
1253
- for (size_t b = 0; b < beam_size; b++) {
1254
- for (size_t k = 0; k < K; k++) {
1255
- cent_distances[b * K + k] =
1256
- distances_i[b] + cd_common[k];
1257
- }
1258
- }
1259
- break;
1260
-
1261
- ACCUM_AND_FINALIZE_TAB(1)
1262
- ACCUM_AND_FINALIZE_TAB(2)
1263
- ACCUM_AND_FINALIZE_TAB(3)
1264
- ACCUM_AND_FINALIZE_TAB(4)
1265
- ACCUM_AND_FINALIZE_TAB(5)
1266
- ACCUM_AND_FINALIZE_TAB(6)
1267
- ACCUM_AND_FINALIZE_TAB(7)
1268
-
1269
- default: {
1270
- // m >= 8 case.
1271
-
1272
- // A temporary buffer has to be used due to the lack of
1273
- // registers. But we'll try to accumulate up to 8 AQ codes in
1274
- // registers and issue a single write operation to the buffer,
1275
- // while the baseline does no accumulation. So, the number of
1276
- // write operations to the temporary buffer is reduced 8x.
1277
-
1278
- // allocate a temporary buffer
1279
- std::vector<float> dp(K);
1280
-
1281
- for (size_t b = 0; b < beam_size; b++) {
1282
- // Initialize it. Compute a sum of first 8 AQ codes
1283
- // because m >= 8 .
1284
- accum_and_store_tab<8, 4>(
1285
- m,
1286
- codebook_cross_norms,
1287
- codebook_offsets,
1288
- codes_i,
1289
- b,
1290
- ldc,
1291
- K,
1292
- dp.data());
1293
-
1294
- #define ACCUM_AND_ADD_TAB(NK) \
1295
- case NK: \
1296
- accum_and_add_tab<NK, 4>( \
1297
- m, \
1298
- codebook_cross_norms, \
1299
- codebook_offsets + im, \
1300
- codes_i + im, \
1301
- b, \
1302
- ldc, \
1303
- K, \
1304
- dp.data()); \
1305
- break;
1306
-
1307
- // accumulate up to 8 additional AQ codes into
1308
- // a temporary buffer
1309
- for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
1310
- size_t m_left = m - im;
1311
- if (m_left > 8) {
1312
- m_left = 8;
1313
- }
1314
-
1315
- switch (m_left) {
1316
- ACCUM_AND_ADD_TAB(1)
1317
- ACCUM_AND_ADD_TAB(2)
1318
- ACCUM_AND_ADD_TAB(3)
1319
- ACCUM_AND_ADD_TAB(4)
1320
- ACCUM_AND_ADD_TAB(5)
1321
- ACCUM_AND_ADD_TAB(6)
1322
- ACCUM_AND_ADD_TAB(7)
1323
- ACCUM_AND_ADD_TAB(8)
1324
- }
1325
- }
1326
-
1327
- // done. finalize the result
1328
- for (size_t k = 0; k < K; k++) {
1329
- cent_distances[b * K + k] =
1330
- distances_i[b] + cd_common[k] + 2 * dp[k];
1331
- }
1332
- }
1333
- }
1334
- }
1335
-
1336
- // the optimized implementation ends here
1337
-
1338
- using C = CMax<float, int>;
1339
- int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
1340
- float* new_distances_i = new_distances + i * new_beam_size;
1341
-
1342
- const float* cent_distances_i = cent_distances.data();
1343
-
1344
- // then we have to select the best results
1345
- for (int i = 0; i < new_beam_size; i++) {
1346
- new_distances_i[i] = C::neutral();
1347
- }
1348
- std::vector<int> perm(new_beam_size, -1);
1349
-
1350
- #define HANDLE_APPROX(NB, BD) \
1351
- case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
1352
- HeapWithBuckets<C, NB, BD>::bs_addn( \
1353
- beam_size, \
1354
- K, \
1355
- cent_distances_i, \
1356
- new_beam_size, \
1357
- new_distances_i, \
1358
- perm.data()); \
1359
- break;
1360
-
1361
- switch (approx_topk_mode) {
1362
- HANDLE_APPROX(8, 3)
1363
- HANDLE_APPROX(8, 2)
1364
- HANDLE_APPROX(16, 2)
1365
- HANDLE_APPROX(32, 2)
1366
- default:
1367
- heap_addn<C>(
1368
- new_beam_size,
1369
- new_distances_i,
1370
- perm.data(),
1371
- cent_distances_i,
1372
- nullptr,
1373
- beam_size * K);
1374
- break;
1375
- }
1376
-
1377
- heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
1378
-
1379
- #undef HANDLE_APPROX
1380
-
1381
- for (int j = 0; j < new_beam_size; j++) {
1382
- int js = perm[j] / K;
1383
- int ls = perm[j] % K;
1384
- if (m > 0) {
1385
- memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
1386
- }
1387
- new_codes_i[m] = ls;
1388
- new_codes_i += m + 1;
1389
- }
1390
- }
1391
- }
1392
-
1393
- //
1394
- void refine_beam_LUT_mp(
1395
- const ResidualQuantizer& rq,
1396
- size_t n,
1397
- const float* query_norms, // size n
1398
- const float* query_cp, //
1399
- int out_beam_size,
1400
- int32_t* out_codes,
1401
- float* out_distances,
1402
- RefineBeamLUTMemoryPool& pool) {
1403
- int beam_size = 1;
1404
-
1405
- double t0 = getmillisecs();
1406
-
1407
- // find the max_beam_size
1408
- int max_beam_size = 0;
1409
- {
1410
- int tmp_beam_size = beam_size;
1411
- for (int m = 0; m < rq.M; m++) {
1412
- int K = 1 << rq.nbits[m];
1413
- int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
1414
- tmp_beam_size = new_beam_size;
1415
-
1416
- if (max_beam_size < new_beam_size) {
1417
- max_beam_size = new_beam_size;
1418
- }
1419
- }
1420
- }
1421
-
1422
- // preallocate buffers
1423
- pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
1424
- pool.new_distances.resize(n * max_beam_size);
1425
-
1426
- pool.codes.resize(n * max_beam_size * (rq.M + 1));
1427
- pool.distances.resize(n * max_beam_size);
1428
-
1429
- for (size_t i = 0; i < n; i++) {
1430
- pool.distances[i] = query_norms[i];
1431
- }
1432
-
1433
- // set up pointers to buffers
1434
- int32_t* __restrict new_codes_ptr = pool.new_codes.data();
1435
- float* __restrict new_distances_ptr = pool.new_distances.data();
1436
-
1437
- int32_t* __restrict codes_ptr = pool.codes.data();
1438
- float* __restrict distances_ptr = pool.distances.data();
1439
-
1440
- // main loop
1441
- size_t codes_size = 0;
1442
- size_t distances_size = 0;
1443
- for (int m = 0; m < rq.M; m++) {
1444
- int K = 1 << rq.nbits[m];
1445
-
1446
- // it is guaranteed that (new_beam_size <= than max_beam_size) ==
1447
- // true
1448
- int new_beam_size = std::min(beam_size * K, out_beam_size);
1449
-
1450
- // std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
1451
- // std::vector<float> new_distances(n * new_beam_size);
1452
-
1453
- codes_size = n * new_beam_size * (m + 1);
1454
- distances_size = n * new_beam_size;
1455
-
1456
- beam_search_encode_step_tab(
1457
- K,
1458
- n,
1459
- beam_size,
1460
- rq.codebook_cross_products.data() + rq.codebook_offsets[m],
1461
- rq.total_codebook_size,
1462
- rq.codebook_offsets.data(),
1463
- query_cp + rq.codebook_offsets[m],
1464
- rq.total_codebook_size,
1465
- rq.cent_norms.data() + rq.codebook_offsets[m],
1466
- m,
1467
- // codes.data(),
1468
- codes_ptr,
1469
- // distances.data(),
1470
- distances_ptr,
1471
- new_beam_size,
1472
- // new_codes.data(),
1473
- new_codes_ptr,
1474
- // new_distances.data()
1475
- new_distances_ptr,
1476
- rq.approx_topk_mode);
1477
-
1478
- // codes.swap(new_codes);
1479
- std::swap(codes_ptr, new_codes_ptr);
1480
- // distances.swap(new_distances);
1481
- std::swap(distances_ptr, new_distances_ptr);
1482
-
1483
- beam_size = new_beam_size;
1484
-
1485
- if (rq.verbose) {
1486
- float sum_distances = 0;
1487
- // for (int j = 0; j < distances.size(); j++) {
1488
- // sum_distances += distances[j];
1489
- // }
1490
- for (int j = 0; j < distances_size; j++) {
1491
- sum_distances += distances_ptr[j];
1492
- }
1493
- printf("[%.3f s] encode stage %d, %d bits, "
1494
- "total error %g, beam_size %d\n",
1495
- (getmillisecs() - t0) / 1000,
1496
- m,
1497
- int(rq.nbits[m]),
1498
- sum_distances,
1499
- beam_size);
1500
- }
1501
- }
1502
-
1503
- if (out_codes) {
1504
- // memcpy(out_codes, codes.data(), codes.size() * sizeof(codes[0]));
1505
- memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
1506
- }
1507
- if (out_distances) {
1508
- // memcpy(out_distances,
1509
- // distances.data(),
1510
- // distances.size() * sizeof(distances[0]));
1511
- memcpy(out_distances,
1512
- distances_ptr,
1513
- distances_size * sizeof(*distances_ptr));
523
+ codebook_cross_products.data() + ofs,
524
+ &ki);
525
+ ofs += ki * kk;
1514
526
  }
1515
527
  }
1516
528