faiss 0.3.0 → 0.3.1

Files changed (171)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +5 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +0 -1
  12. data/vendor/faiss/faiss/Clustering.cpp +4 -18
  13. data/vendor/faiss/faiss/Clustering.h +31 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +22 -11
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +20 -5
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  25. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +106 -187
  26. data/vendor/faiss/faiss/IndexFastScan.cpp +90 -159
  27. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  28. data/vendor/faiss/faiss/IndexFlat.cpp +195 -3
  29. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  30. data/vendor/faiss/faiss/IndexFlatCodes.cpp +11 -0
  31. data/vendor/faiss/faiss/IndexFlatCodes.h +3 -1
  32. data/vendor/faiss/faiss/IndexHNSW.cpp +112 -316
  33. data/vendor/faiss/faiss/IndexHNSW.h +12 -48
  34. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  35. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  36. data/vendor/faiss/faiss/IndexIVF.cpp +159 -53
  37. data/vendor/faiss/faiss/IndexIVF.h +37 -5
  38. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +18 -26
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +3 -2
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  42. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +433 -405
  43. data/vendor/faiss/faiss/IndexIVFFastScan.h +56 -26
  44. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  45. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  46. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  48. data/vendor/faiss/faiss/IndexIVFPQ.cpp +78 -122
  49. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  50. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +18 -50
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  52. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  53. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  54. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  56. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  57. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -4
  58. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  59. data/vendor/faiss/faiss/IndexNSG.h +10 -10
  60. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  61. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  62. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  63. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  64. data/vendor/faiss/faiss/IndexRefine.cpp +49 -19
  65. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  66. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  67. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +22 -16
  68. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  69. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  70. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  71. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  72. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  73. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  74. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  75. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  76. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  77. data/vendor/faiss/faiss/clone_index.h +3 -0
  78. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +87 -4
  79. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  80. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +7 -0
  81. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  82. data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
  83. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  84. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +8 -9
  85. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +18 -3
  86. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  87. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  88. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  89. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  90. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +117 -17
  91. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  92. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +1 -1
  93. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  94. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  95. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +267 -40
  96. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  97. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  98. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  99. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  100. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  101. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +1 -1
  102. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +1 -2
  103. data/vendor/faiss/faiss/impl/DistanceComputer.h +24 -1
  104. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  105. data/vendor/faiss/faiss/impl/HNSW.cpp +321 -70
  106. data/vendor/faiss/faiss/impl/HNSW.h +9 -8
  107. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  108. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +3 -1
  109. data/vendor/faiss/faiss/impl/NNDescent.cpp +29 -19
  110. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  111. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  112. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  113. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +24 -22
  114. data/vendor/faiss/faiss/impl/ProductQuantizer.h +1 -1
  115. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  116. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  117. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  118. data/vendor/faiss/faiss/impl/ResultHandler.h +232 -176
  119. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +444 -104
  120. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +0 -8
  121. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +280 -42
  122. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  123. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  124. data/vendor/faiss/faiss/impl/index_read.cpp +45 -19
  125. data/vendor/faiss/faiss/impl/index_write.cpp +60 -41
  126. data/vendor/faiss/faiss/impl/io.cpp +10 -10
  127. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  128. data/vendor/faiss/faiss/impl/platform_macros.h +18 -1
  129. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +3 -0
  130. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  131. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  132. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +40 -49
  133. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  134. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  135. data/vendor/faiss/faiss/impl/simd_result_handlers.h +374 -202
  136. data/vendor/faiss/faiss/index_factory.cpp +10 -7
  137. data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -1
  138. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +27 -9
  139. data/vendor/faiss/faiss/invlists/InvertedLists.h +12 -3
  140. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +3 -3
  141. data/vendor/faiss/faiss/python/python_callbacks.cpp +1 -1
  142. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  143. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  144. data/vendor/faiss/faiss/utils/distances.cpp +128 -74
  145. data/vendor/faiss/faiss/utils/distances.h +81 -4
  146. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  147. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  148. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  149. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  150. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  151. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  152. data/vendor/faiss/faiss/utils/distances_simd.cpp +428 -70
  153. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  154. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  155. data/vendor/faiss/faiss/utils/hamming.cpp +162 -110
  156. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  157. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  158. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  159. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +15 -87
  160. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +57 -0
  161. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  162. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  163. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  164. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  165. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  166. data/vendor/faiss/faiss/utils/simdlib_neon.h +72 -77
  167. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  168. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  169. data/vendor/faiss/faiss/utils/utils.cpp +112 -6
  170. data/vendor/faiss/faiss/utils/utils.h +57 -20
  171. metadata +10 -3
data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp  +960 -0
@@ -0,0 +1,960 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <faiss/impl/residual_quantizer_encode_steps.h>
+
+#include <faiss/impl/AuxIndexStructures.h>
+#include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/ResidualQuantizer.h>
+#include <faiss/utils/Heap.h>
+#include <faiss/utils/distances.h>
+#include <faiss/utils/simdlib.h>
+#include <faiss/utils/utils.h>
+
+#include <faiss/utils/approx_topk/approx_topk.h>
+
+extern "C" {
+
+// general matrix multiplication
+int sgemm_(
+        const char* transa,
+        const char* transb,
+        FINTEGER* m,
+        FINTEGER* n,
+        FINTEGER* k,
+        const float* alpha,
+        const float* a,
+        FINTEGER* lda,
+        const float* b,
+        FINTEGER* ldb,
+        float* beta,
+        float* c,
+        FINTEGER* ldc);
+}
+
+namespace faiss {
+
+/********************************************************************
+ * Basic routines
+ ********************************************************************/
+
+namespace {
+
+template <size_t M, size_t NK>
+void accum_and_store_tab(
+        const size_t m_offset,
+        const float* const __restrict codebook_cross_norms,
+        const uint64_t* const __restrict codebook_offsets,
+        const int32_t* const __restrict codes_i,
+        const size_t b,
+        const size_t ldc,
+        const size_t K,
+        float* const __restrict output) {
+    // load pointers into registers
+    const float* cbs[M];
+    for (size_t ij = 0; ij < M; ij++) {
+        const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
+        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
+    }
+
+    // do accumulation in registers using SIMD.
+    // It is possible that compiler may be smart enough so that
+    // this manual SIMD unrolling might be unneeded.
+#if defined(__AVX2__) || defined(__aarch64__)
+    const size_t K8 = (K / (8 * NK)) * (8 * NK);
+
+    // process in chunks of size (8 * NK) floats
+    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
+        simd8float32 regs[NK];
+        for (size_t ik = 0; ik < NK; ik++) {
+            regs[ik].loadu(cbs[0] + kk + ik * 8);
+        }
+
+        for (size_t ij = 1; ij < M; ij++) {
+            for (size_t ik = 0; ik < NK; ik++) {
+                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
+            }
+        }
+
+        // write the result
+        for (size_t ik = 0; ik < NK; ik++) {
+            regs[ik].storeu(output + kk + ik * 8);
+        }
+    }
+#else
+    const size_t K8 = 0;
+#endif
+
+    // process leftovers
+    for (size_t kk = K8; kk < K; kk++) {
+        float reg = cbs[0][kk];
+        for (size_t ij = 1; ij < M; ij++) {
+            reg += cbs[ij][kk];
+        }
+        output[b * K + kk] = reg;
+    }
+}
+
+template <size_t M, size_t NK>
+void accum_and_add_tab(
+        const size_t m_offset,
+        const float* const __restrict codebook_cross_norms,
+        const uint64_t* const __restrict codebook_offsets,
+        const int32_t* const __restrict codes_i,
+        const size_t b,
+        const size_t ldc,
+        const size_t K,
+        float* const __restrict output) {
+    // load pointers into registers
+    const float* cbs[M];
+    for (size_t ij = 0; ij < M; ij++) {
+        const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
+        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
+    }
+
+    // do accumulation in registers using SIMD.
+    // It is possible that compiler may be smart enough so that
+    // this manual SIMD unrolling might be unneeded.
+#if defined(__AVX2__) || defined(__aarch64__)
+    const size_t K8 = (K / (8 * NK)) * (8 * NK);
+
+    // process in chunks of size (8 * NK) floats
+    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
+        simd8float32 regs[NK];
+        for (size_t ik = 0; ik < NK; ik++) {
+            regs[ik].loadu(cbs[0] + kk + ik * 8);
+        }
+
+        for (size_t ij = 1; ij < M; ij++) {
+            for (size_t ik = 0; ik < NK; ik++) {
+                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
+            }
+        }
+
+        // write the result
+        for (size_t ik = 0; ik < NK; ik++) {
+            simd8float32 existing(output + kk + ik * 8);
+            existing += regs[ik];
+            existing.storeu(output + kk + ik * 8);
+        }
+    }
+#else
+    const size_t K8 = 0;
+#endif
+
+    // process leftovers
+    for (size_t kk = K8; kk < K; kk++) {
+        float reg = cbs[0][kk];
+        for (size_t ij = 1; ij < M; ij++) {
+            reg += cbs[ij][kk];
+        }
+        output[b * K + kk] += reg;
+    }
+}
+
+template <size_t M, size_t NK>
+void accum_and_finalize_tab(
+        const float* const __restrict codebook_cross_norms,
+        const uint64_t* const __restrict codebook_offsets,
+        const int32_t* const __restrict codes_i,
+        const size_t b,
+        const size_t ldc,
+        const size_t K,
+        const float* const __restrict distances_i,
+        const float* const __restrict cd_common,
+        float* const __restrict output) {
+    // load pointers into registers
+    const float* cbs[M];
+    for (size_t ij = 0; ij < M; ij++) {
+        const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
+        cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
+    }
+
+    // do accumulation in registers using SIMD.
+    // It is possible that compiler may be smart enough so that
+    // this manual SIMD unrolling might be unneeded.
+#if defined(__AVX2__) || defined(__aarch64__)
+    const size_t K8 = (K / (8 * NK)) * (8 * NK);
+
+    // process in chunks of size (8 * NK) floats
+    for (size_t kk = 0; kk < K8; kk += 8 * NK) {
+        simd8float32 regs[NK];
+        for (size_t ik = 0; ik < NK; ik++) {
+            regs[ik].loadu(cbs[0] + kk + ik * 8);
+        }
+
+        for (size_t ij = 1; ij < M; ij++) {
+            for (size_t ik = 0; ik < NK; ik++) {
+                regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
+            }
+        }
+
+        simd8float32 two(2.0f);
+        for (size_t ik = 0; ik < NK; ik++) {
+            // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
+            // + 2 * dp[k];
+
+            simd8float32 common_v(cd_common + kk + ik * 8);
+            common_v = fmadd(two, regs[ik], common_v);
+
+            common_v += simd8float32(distances_i[b]);
+            common_v.storeu(output + b * K + kk + ik * 8);
+        }
+    }
+#else
+    const size_t K8 = 0;
+#endif
+
+    // process leftovers
+    for (size_t kk = K8; kk < K; kk++) {
+        float reg = cbs[0][kk];
+        for (size_t ij = 1; ij < M; ij++) {
+            reg += cbs[ij][kk];
+        }
+
+        output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
+    }
+}
+
+} // anonymous namespace
+
+/********************************************************************
+ * Single encoding step
+ ********************************************************************/
+
+void beam_search_encode_step(
+        size_t d,
+        size_t K,
+        const float* cent, /// size (K, d)
+        size_t n,
+        size_t beam_size,
+        const float* residuals, /// size (n, beam_size, d)
+        size_t m,
+        const int32_t* codes, /// size (n, beam_size, m)
+        size_t new_beam_size,
+        int32_t* new_codes, /// size (n, new_beam_size, m + 1)
+        float* new_residuals, /// size (n, new_beam_size, d)
+        float* new_distances, /// size (n, new_beam_size)
+        Index* assign_index,
+        ApproxTopK_mode_t approx_topk_mode) {
+    // we have to fill in the whole output matrix
+    FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
+
+    std::vector<float> cent_distances;
+    std::vector<idx_t> cent_ids;
+
+    if (assign_index) {
+        // search beam_size distances per query
+        FAISS_THROW_IF_NOT(assign_index->d == d);
+        cent_distances.resize(n * beam_size * new_beam_size);
+        cent_ids.resize(n * beam_size * new_beam_size);
+        if (assign_index->ntotal != 0) {
+            // then we assume the codebooks are already added to the index
+            FAISS_THROW_IF_NOT(assign_index->ntotal == K);
+        } else {
+            assign_index->add(K, cent);
+        }
+
+        // printf("beam_search_encode_step -- mem usage %zd\n",
+        // get_mem_usage_kb());
+        assign_index->search(
+                n * beam_size,
+                residuals,
+                new_beam_size,
+                cent_distances.data(),
+                cent_ids.data());
+    } else {
+        // do one big distance computation
+        cent_distances.resize(n * beam_size * K);
+        pairwise_L2sqr(
+                d, n * beam_size, residuals, K, cent, cent_distances.data());
+    }
+    InterruptCallback::check();
+
+#pragma omp parallel for if (n > 100)
+    for (int64_t i = 0; i < n; i++) {
+        const int32_t* codes_i = codes + i * m * beam_size;
+        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
+        const float* residuals_i = residuals + i * d * beam_size;
+        float* new_residuals_i = new_residuals + i * d * new_beam_size;
+
+        float* new_distances_i = new_distances + i * new_beam_size;
+        using C = CMax<float, int>;
+
+        if (assign_index) {
+            const float* cent_distances_i =
+                    cent_distances.data() + i * beam_size * new_beam_size;
+            const idx_t* cent_ids_i =
+                    cent_ids.data() + i * beam_size * new_beam_size;
+
+            // here we could be a tad more efficient by merging sorted arrays
+            for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
+                new_distances_i[i_2] = C::neutral();
+            }
+            std::vector<int> perm(new_beam_size, -1);
+            heap_addn<C>(
+                    new_beam_size,
+                    new_distances_i,
+                    perm.data(),
+                    cent_distances_i,
+                    nullptr,
+                    beam_size * new_beam_size);
+            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
+
+            for (int j = 0; j < new_beam_size; j++) {
+                int js = perm[j] / new_beam_size;
+                int ls = cent_ids_i[perm[j]];
+                if (m > 0) {
+                    memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
+                }
+                new_codes_i[m] = ls;
+                new_codes_i += m + 1;
+                fvec_sub(
+                        d,
+                        residuals_i + js * d,
+                        cent + ls * d,
+                        new_residuals_i);
+                new_residuals_i += d;
+            }
+
+        } else {
+            const float* cent_distances_i =
+                    cent_distances.data() + i * beam_size * K;
+            // then we have to select the best results
+            for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
+                new_distances_i[i_2] = C::neutral();
+            }
+            std::vector<int> perm(new_beam_size, -1);
+
+#define HANDLE_APPROX(NB, BD) \
+    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
+        HeapWithBuckets<C, NB, BD>::bs_addn( \
+                beam_size, \
+                K, \
+                cent_distances_i, \
+                new_beam_size, \
+                new_distances_i, \
+                perm.data()); \
+        break;
+
+            switch (approx_topk_mode) {
+                HANDLE_APPROX(8, 3)
+                HANDLE_APPROX(8, 2)
+                HANDLE_APPROX(16, 2)
+                HANDLE_APPROX(32, 2)
+                default:
+                    heap_addn<C>(
+                            new_beam_size,
+                            new_distances_i,
+                            perm.data(),
+                            cent_distances_i,
+                            nullptr,
+                            beam_size * K);
+            }
+            heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
+
+#undef HANDLE_APPROX
+
+            for (int j = 0; j < new_beam_size; j++) {
+                int js = perm[j] / K;
+                int ls = perm[j] % K;
+                if (m > 0) {
+                    memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
+                }
+                new_codes_i[m] = ls;
+                new_codes_i += m + 1;
+                fvec_sub(
+                        d,
+                        residuals_i + js * d,
+                        cent + ls * d,
+                        new_residuals_i);
+                new_residuals_i += d;
+            }
+        }
+    }
+}
+
+// exposed in the faiss namespace
+void beam_search_encode_step_tab(
+        size_t K,
+        size_t n,
+        size_t beam_size, // input sizes
+        const float* codebook_cross_norms, // size K * ldc
+        size_t ldc,
+        const uint64_t* codebook_offsets, // m
+        const float* query_cp, // size n * ldqc
+        size_t ldqc, // >= K
+        const float* cent_norms_i, // size K
+        size_t m,
+        const int32_t* codes, // n * beam_size * m
+        const float* distances, // n * beam_size
+        size_t new_beam_size,
+        int32_t* new_codes, // n * new_beam_size * (m + 1)
+        float* new_distances, // n * new_beam_size
+        ApproxTopK_mode_t approx_topk_mode) //
+{
+    FAISS_THROW_IF_NOT(ldc >= K);
+
+#pragma omp parallel for if (n > 100) schedule(dynamic)
+    for (int64_t i = 0; i < n; i++) {
+        std::vector<float> cent_distances(beam_size * K);
+        std::vector<float> cd_common(K);
+
+        const int32_t* codes_i = codes + i * m * beam_size;
+        const float* query_cp_i = query_cp + i * ldqc;
+        const float* distances_i = distances + i * beam_size;
+
+        for (size_t k = 0; k < K; k++) {
+            cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
+        }
+
+        bool use_baseline_implementation = false;
+
+        // This is the baseline implementation. Its primary flaw
+        // that it writes way too many info to the temporary buffer
+        // called dp.
+        //
+        // This baseline code is kept intentionally because it is easy to
+        // understand what an optimized version optimizes exactly.
+        //
+        if (use_baseline_implementation) {
+            for (size_t b = 0; b < beam_size; b++) {
+                std::vector<float> dp(K);
+
+                for (size_t m1 = 0; m1 < m; m1++) {
+                    size_t c = codes_i[b * m + m1];
+                    const float* cb =
+                            &codebook_cross_norms
+                                    [(codebook_offsets[m1] + c) * ldc];
+                    fvec_add(K, cb, dp.data(), dp.data());
+                }
+
+                for (size_t k = 0; k < K; k++) {
+                    cent_distances[b * K + k] =
+                            distances_i[b] + cd_common[k] + 2 * dp[k];
+                }
+            }
+
+        } else {
+            // An optimized implementation that avoids using a temporary buffer
+            // and does the accumulation in registers.
+
+            // Compute a sum of NK AQ codes.
+#define ACCUM_AND_FINALIZE_TAB(NK) \
+    case NK: \
+        for (size_t b = 0; b < beam_size; b++) { \
+            accum_and_finalize_tab<NK, 4>( \
+                    codebook_cross_norms, \
+                    codebook_offsets, \
+                    codes_i, \
+                    b, \
+                    ldc, \
+                    K, \
+                    distances_i, \
+                    cd_common.data(), \
+                    cent_distances.data()); \
+        } \
+        break;
+
+            // this version contains many switch-case scenarios, but
+            // they won't affect branch predictor.
+            switch (m) {
+                case 0:
+                    // trivial case
+                    for (size_t b = 0; b < beam_size; b++) {
+                        for (size_t k = 0; k < K; k++) {
+                            cent_distances[b * K + k] =
+                                    distances_i[b] + cd_common[k];
+                        }
+                    }
+                    break;
+
+                ACCUM_AND_FINALIZE_TAB(1)
+                ACCUM_AND_FINALIZE_TAB(2)
+                ACCUM_AND_FINALIZE_TAB(3)
+                ACCUM_AND_FINALIZE_TAB(4)
+                ACCUM_AND_FINALIZE_TAB(5)
+                ACCUM_AND_FINALIZE_TAB(6)
+                ACCUM_AND_FINALIZE_TAB(7)
+
+                default: {
+                    // m >= 8 case.
+
+                    // A temporary buffer has to be used due to the lack of
+                    // registers. But we'll try to accumulate up to 8 AQ codes
+                    // in registers and issue a single write operation to the
+                    // buffer, while the baseline does no accumulation. So, the
+                    // number of write operations to the temporary buffer is
+                    // reduced 8x.
+
+                    // allocate a temporary buffer
+                    std::vector<float> dp(K);
+
+                    for (size_t b = 0; b < beam_size; b++) {
+                        // Initialize it. Compute a sum of first 8 AQ codes
+                        // because m >= 8 .
+                        accum_and_store_tab<8, 4>(
+                                m,
+                                codebook_cross_norms,
+                                codebook_offsets,
+                                codes_i,
+                                b,
+                                ldc,
+                                K,
+                                dp.data());
+
+#define ACCUM_AND_ADD_TAB(NK) \
+    case NK: \
+        accum_and_add_tab<NK, 4>( \
+                m, \
+                codebook_cross_norms, \
+                codebook_offsets + im, \
+                codes_i + im, \
+                b, \
+                ldc, \
+                K, \
+                dp.data()); \
+        break;
+
+                        // accumulate up to 8 additional AQ codes into
+                        // a temporary buffer
+                        for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
+                            size_t m_left = m - im;
+                            if (m_left > 8) {
+                                m_left = 8;
+                            }
+
+                            switch (m_left) {
+                                ACCUM_AND_ADD_TAB(1)
+                                ACCUM_AND_ADD_TAB(2)
+                                ACCUM_AND_ADD_TAB(3)
+                                ACCUM_AND_ADD_TAB(4)
+                                ACCUM_AND_ADD_TAB(5)
+                                ACCUM_AND_ADD_TAB(6)
+                                ACCUM_AND_ADD_TAB(7)
+                                ACCUM_AND_ADD_TAB(8)
+                            }
+                        }
+
+                        // done. finalize the result
+                        for (size_t k = 0; k < K; k++) {
+                            cent_distances[b * K + k] =
+                                    distances_i[b] + cd_common[k] + 2 * dp[k];
+                        }
+                    }
+                }
+            }
+
+            // the optimized implementation ends here
+        }
+        using C = CMax<float, int>;
+        int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
+        float* new_distances_i = new_distances + i * new_beam_size;
+
+        const float* cent_distances_i = cent_distances.data();
+
+        // then we have to select the best results
+        for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
+            new_distances_i[i_2] = C::neutral();
+        }
+        std::vector<int> perm(new_beam_size, -1);
+
+#define HANDLE_APPROX(NB, BD) \
+    case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
+        HeapWithBuckets<C, NB, BD>::bs_addn( \
+                beam_size, \
+                K, \
+                cent_distances_i, \
+                new_beam_size, \
+                new_distances_i, \
+                perm.data()); \
+        break;
+
+        switch (approx_topk_mode) {
+            HANDLE_APPROX(8, 3)
+            HANDLE_APPROX(8, 2)
+            HANDLE_APPROX(16, 2)
+            HANDLE_APPROX(32, 2)
+            default:
+                heap_addn<C>(
+                        new_beam_size,
+                        new_distances_i,
+                        perm.data(),
+                        cent_distances_i,
+                        nullptr,
+                        beam_size * K);
+                break;
+        }
+
+        heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
+
+#undef HANDLE_APPROX
+
+        for (int j = 0; j < new_beam_size; j++) {
+            int js = perm[j] / K;
+            int ls = perm[j] % K;
+            if (m > 0) {
+                memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
+            }
+            new_codes_i[m] = ls;
+            new_codes_i += m + 1;
+        }
+    }
+}
+
+/********************************************************************
+ * Multiple encoding steps
+ ********************************************************************/
+
+namespace rq_encode_steps {
+
+void refine_beam_mp(
+        const ResidualQuantizer& rq,
+        size_t n,
+        size_t beam_size,
+        const float* x,
+        int out_beam_size,
+        int32_t* out_codes,
+        float* out_residuals,
+        float* out_distances,
+        RefineBeamMemoryPool& pool) {
+    int cur_beam_size = beam_size;
+
+    double t0 = getmillisecs();
+
+    // find the max_beam_size
+    int max_beam_size = 0;
+    {
+        int tmp_beam_size = cur_beam_size;
+        for (int m = 0; m < rq.M; m++) {
+            int K = 1 << rq.nbits[m];
+            int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
+            tmp_beam_size = new_beam_size;
+
+            if (max_beam_size < new_beam_size) {
+                max_beam_size = new_beam_size;
+            }
+        }
+    }
+
+    // preallocate buffers
+    pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
+    pool.new_residuals.resize(n * max_beam_size * rq.d);
+
+    pool.codes.resize(n * max_beam_size * (rq.M + 1));
+    pool.distances.resize(n * max_beam_size);
+    pool.residuals.resize(n * rq.d * max_beam_size);
+
+    for (size_t i = 0; i < n * rq.d * beam_size; i++) {
+        pool.residuals[i] = x[i];
+    }
+
+    // set up pointers to buffers
+    int32_t* __restrict codes_ptr = pool.codes.data();
+    float* __restrict residuals_ptr = pool.residuals.data();
+
+    int32_t* __restrict new_codes_ptr = pool.new_codes.data();
+    float* __restrict new_residuals_ptr = pool.new_residuals.data();
+
+    // index
+    std::unique_ptr<Index> assign_index;
+    if (rq.assign_index_factory) {
+        assign_index.reset((*rq.assign_index_factory)(rq.d));
+    } else {
+        assign_index.reset(new IndexFlatL2(rq.d));
+    }
+
+    // main loop
+    size_t codes_size = 0;
+    size_t distances_size = 0;
+    size_t residuals_size = 0;
+
+    for (int m = 0; m < rq.M; m++) {
+        int K = 1 << rq.nbits[m];
+
+        const float* __restrict codebooks_m =
+                rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
+
+        const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
+
+        codes_size = n * new_beam_size * (m + 1);
+        residuals_size = n * new_beam_size * rq.d;
+        distances_size = n * new_beam_size;
+
+        beam_search_encode_step(
+                rq.d,
+                K,
+                codebooks_m,
+                n,
+                cur_beam_size,
+                residuals_ptr,
+                m,
+                codes_ptr,
+                new_beam_size,
+                new_codes_ptr,
+                new_residuals_ptr,
+                pool.distances.data(),
+                assign_index.get(),
+                rq.approx_topk_mode);
+
+        assign_index->reset();
+
+        std::swap(codes_ptr, new_codes_ptr);
+        std::swap(residuals_ptr, new_residuals_ptr);
+
+        cur_beam_size = new_beam_size;
+
+        if (rq.verbose) {
+            float sum_distances = 0;
+            for (int j = 0; j < distances_size; j++) {
+                sum_distances += pool.distances[j];
+            }
+
+            printf("[%.3f s] encode stage %d, %d bits, "
+                   "total error %g, beam_size %d\n",
+                   (getmillisecs() - t0) / 1000,
+                   m,
+                   int(rq.nbits[m]),
+                   sum_distances,
+                   cur_beam_size);
+        }
+    }
+
+    if (out_codes) {
+        memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
+    }
+    if (out_residuals) {
+        memcpy(out_residuals,
+               residuals_ptr,
+               residuals_size * sizeof(*residuals_ptr));
+    }
+    if (out_distances) {
+        memcpy(out_distances,
+               pool.distances.data(),
+               distances_size * sizeof(pool.distances[0]));
+    }
+}
+
+void refine_beam_LUT_mp(
+        const ResidualQuantizer& rq,
+        size_t n,
+        const float* query_norms, // size n
+        const float* query_cp, //
+        int out_beam_size,
+        int32_t* out_codes,
+        float* out_distances,
+        RefineBeamLUTMemoryPool& pool) {
+    int beam_size = 1;
+
+    double t0 = getmillisecs();
+
+    // find the max_beam_size
+    int max_beam_size = 0;
+    {
+        int tmp_beam_size = beam_size;
+        for (int m = 0; m < rq.M; m++) {
+            int K = 1 << rq.nbits[m];
+            int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
+            tmp_beam_size = new_beam_size;
+
+            if (max_beam_size < new_beam_size) {
+                max_beam_size = new_beam_size;
+            }
+        }
+    }
+
+    // preallocate buffers
+    pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
+    pool.new_distances.resize(n * max_beam_size);
+
+    pool.codes.resize(n * max_beam_size * (rq.M + 1));
+    pool.distances.resize(n * max_beam_size);
+
+    for (size_t i = 0; i < n; i++) {
+        pool.distances[i] = query_norms[i];
+    }
+
+    // set up pointers to buffers
+    int32_t* __restrict new_codes_ptr = pool.new_codes.data();
+    float* __restrict new_distances_ptr = pool.new_distances.data();
+
+    int32_t* __restrict codes_ptr = pool.codes.data();
+    float* __restrict distances_ptr = pool.distances.data();
+
+    // main loop
+    size_t codes_size = 0;
+    size_t distances_size = 0;
+    size_t cross_ofs = 0;
+    for (int m = 0; m < rq.M; m++) {
+        int K = 1 << rq.nbits[m];
+
+        // it is guaranteed that (new_beam_size <= max_beam_size)
+        int new_beam_size = std::min(beam_size * K, out_beam_size);
+
+        codes_size = n * new_beam_size * (m + 1);
+        distances_size = n * new_beam_size;
+        FAISS_THROW_IF_NOT(
+                cross_ofs + rq.codebook_offsets[m] * K <=
+                rq.codebook_cross_products.size());
+        beam_search_encode_step_tab(
+                K,
+                n,
+                beam_size,
+                rq.codebook_cross_products.data() + cross_ofs,
+                K,
+                rq.codebook_offsets.data(),
+                query_cp + rq.codebook_offsets[m],
+                rq.total_codebook_size,
+                rq.cent_norms.data() + rq.codebook_offsets[m],
+                m,
+                codes_ptr,
+                distances_ptr,
+                new_beam_size,
+                new_codes_ptr,
+                new_distances_ptr,
+                rq.approx_topk_mode);
+        cross_ofs += rq.codebook_offsets[m] * K;
+        std::swap(codes_ptr, new_codes_ptr);
+        std::swap(distances_ptr, new_distances_ptr);
+
+        beam_size = new_beam_size;
+
+        if (rq.verbose) {
+            float sum_distances = 0;
+            for (int j = 0; j < distances_size; j++) {
+                sum_distances += distances_ptr[j];
+            }
+            printf("[%.3f s] encode stage %d, %d bits, "
+                   "total error %g, beam_size %d\n",
+                   (getmillisecs() - t0) / 1000,
+                   m,
+                   int(rq.nbits[m]),
+                   sum_distances,
+                   beam_size);
+        }
+    }
+    if (out_codes) {
+        memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
+    }
+    if (out_distances) {
+        memcpy(out_distances,
+               distances_ptr,
+               distances_size * sizeof(*distances_ptr));
+    }
+}
+
+// this is for use_beam_LUT == 0
+void compute_codes_add_centroids_mp_lut0(
+        const ResidualQuantizer& rq,
+        const float* x,
+        uint8_t* codes_out,
+        size_t n,
+        const float* centroids,
+        ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
+    pool.codes.resize(rq.max_beam_size * rq.M * n);
+    pool.distances.resize(rq.max_beam_size * n);
+
+    pool.residuals.resize(rq.max_beam_size * n * rq.d);
+
+    refine_beam_mp(
+            rq,
+            n,
+            1,
+            x,
+            rq.max_beam_size,
+            pool.codes.data(),
+            pool.residuals.data(),
+            pool.distances.data(),
+            pool.refine_beam_pool);
+
+    if (rq.search_type == ResidualQuantizer::ST_norm_float ||
+        rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
+        rq.search_type == ResidualQuantizer::ST_norm_qint4) {
+        pool.norms.resize(n);
+        // recover the norms of reconstruction as
+        // || original_vector - residual ||^2
+        for (size_t i = 0; i < n; i++) {
+            pool.norms[i] = fvec_L2sqr(
+                    x + i * rq.d,
+                    pool.residuals.data() + i * rq.max_beam_size * rq.d,
+                    rq.d);
+        }
+    }
+
+    // pack only the first code of the beam
+    // (hence the ld_codes=M * max_beam_size)
+    rq.pack_codes(
+            n,
+            pool.codes.data(),
+            codes_out,
+            rq.M * rq.max_beam_size,
+            (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
+            centroids);
+}
+
+// use_beam_LUT == 1
+void compute_codes_add_centroids_mp_lut1(
+        const ResidualQuantizer& rq,
+        const float* x,
+        uint8_t* codes_out,
+        size_t n,
+        const float* centroids,
+        ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
+    //
+    pool.codes.resize(rq.max_beam_size * rq.M * n);
+    pool.distances.resize(rq.max_beam_size * n);
+
+    FAISS_THROW_IF_NOT_MSG(
+            rq.M == 1 || rq.codebook_cross_products.size() > 0,
+            "call compute_codebook_tables first");
+
+    pool.query_norms.resize(n);
+    fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
+
+    pool.query_cp.resize(n * rq.total_codebook_size);
+    {
+        FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
+        float zero = 0, one = 1;
+        sgemm_("Transposed",
+               "Not transposed",
+               &ti,
+               &ni,
+               &di,
+               &one,
+               rq.codebooks.data(),
+               &di,
+               x,
+               &di,
+               &zero,
+               pool.query_cp.data(),
+               &ti);
+    }
+
+    refine_beam_LUT_mp(
+            rq,
+            n,
+            pool.query_norms.data(),
+            pool.query_cp.data(),
+            rq.max_beam_size,
+            pool.codes.data(),
+            pool.distances.data(),
+            pool.refine_beam_lut_pool);
+
+    // pack only the first code of the beam
+    // (hence the ld_codes=M * max_beam_size)
+    rq.pack_codes(
+            n,
+            pool.codes.data(),
+            codes_out,
+            rq.M * rq.max_beam_size,
+            nullptr,
+            centroids);
+}
+
+} // namespace rq_encode_steps
+
+} // namespace faiss
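
As a quick orientation for readers of this diff, the following is a minimal, hypothetical driver (not part of the gem or of the vendored faiss sources) that exercises the beam_search_encode_step() shown above for a toy case: one 2-D residual, four centroids, no assign index (so the brute-force pairwise_L2sqr branch is taken), and exact top-k selection. It assumes the function is declared via faiss/impl/ResidualQuantizer.h and that ApproxTopK_mode_t::EXACT_TOPK is the exact-search mode; both are assumptions about the surrounding faiss headers rather than facts taken from this diff.

#include <faiss/impl/ResidualQuantizer.h>        // assumed to declare beam_search_encode_step
#include <faiss/utils/approx_topk/approx_topk.h> // brings in ApproxTopK_mode_t

#include <cstdio>
#include <vector>

int main() {
    using namespace faiss;

    const size_t d = 2, K = 4, n = 1, beam_size = 1, m = 0;
    const size_t new_beam_size = 2;

    // four 2-D centroids and a single residual to encode
    std::vector<float> cent = {0, 0, 1, 0, 0, 1, 1, 1};
    std::vector<float> residuals = {0.9f, 0.1f};

    std::vector<int32_t> new_codes(n * new_beam_size * (m + 1));
    std::vector<float> new_residuals(n * new_beam_size * d);
    std::vector<float> new_distances(n * new_beam_size);

    beam_search_encode_step(
            d,
            K,
            cent.data(),
            n,
            beam_size,
            residuals.data(),
            m,
            nullptr, // no previous codes, since m == 0
            new_beam_size,
            new_codes.data(),
            new_residuals.data(),
            new_distances.data(),
            nullptr, // no assign index: use the pairwise_L2sqr branch
            ApproxTopK_mode_t::EXACT_TOPK); // assumed exact-search mode

    // prints the two centroids closest to (0.9, 0.1) with their squared L2 distances
    for (size_t j = 0; j < new_beam_size; j++) {
        printf("code %d  distance %g\n", (int)new_codes[j], new_distances[j]);
    }
    return 0;
}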