faiss 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (216):
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +1 -1
  5. data/ext/faiss/extconf.rb +9 -2
  6. data/ext/faiss/index.cpp +1 -1
  7. data/ext/faiss/index_binary.cpp +2 -2
  8. data/ext/faiss/product_quantizer.cpp +1 -1
  9. data/lib/faiss/version.rb +1 -1
  10. data/vendor/faiss/faiss/AutoTune.cpp +7 -7
  11. data/vendor/faiss/faiss/AutoTune.h +1 -2
  12. data/vendor/faiss/faiss/Clustering.cpp +39 -22
  13. data/vendor/faiss/faiss/Clustering.h +40 -21
  14. data/vendor/faiss/faiss/IVFlib.cpp +26 -12
  15. data/vendor/faiss/faiss/Index.cpp +1 -1
  16. data/vendor/faiss/faiss/Index.h +40 -10
  17. data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
  18. data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
  19. data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
  20. data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
  21. data/vendor/faiss/faiss/IndexBinary.h +8 -19
  22. data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
  23. data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
  24. data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
  25. data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
  26. data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
  27. data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
  28. data/vendor/faiss/faiss/IndexFastScan.h +9 -8
  29. data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
  30. data/vendor/faiss/faiss/IndexFlat.h +20 -1
  31. data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
  32. data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
  33. data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
  34. data/vendor/faiss/faiss/IndexHNSW.h +62 -49
  35. data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
  36. data/vendor/faiss/faiss/IndexIDMap.h +24 -2
  37. data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
  38. data/vendor/faiss/faiss/IndexIVF.h +46 -6
  39. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
  40. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
  41. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
  42. data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
  43. data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
  44. data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
  45. data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
  46. data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
  47. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
  48. data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
  49. data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
  50. data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
  51. data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
  52. data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
  53. data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
  54. data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
  55. data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
  56. data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
  57. data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
  58. data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
  59. data/vendor/faiss/faiss/IndexLattice.h +3 -22
  60. data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
  61. data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
  62. data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
  63. data/vendor/faiss/faiss/IndexNSG.h +11 -11
  64. data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
  65. data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
  66. data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
  67. data/vendor/faiss/faiss/IndexPQ.h +1 -4
  68. data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
  69. data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
  70. data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
  71. data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
  72. data/vendor/faiss/faiss/IndexRefine.h +7 -0
  73. data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
  74. data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
  75. data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
  76. data/vendor/faiss/faiss/IndexShards.cpp +21 -29
  77. data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
  78. data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
  79. data/vendor/faiss/faiss/MatrixStats.h +21 -9
  80. data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
  81. data/vendor/faiss/faiss/MetricType.h +7 -2
  82. data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
  83. data/vendor/faiss/faiss/VectorTransform.h +7 -7
  84. data/vendor/faiss/faiss/clone_index.cpp +15 -10
  85. data/vendor/faiss/faiss/clone_index.h +3 -0
  86. data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
  87. data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
  88. data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
  89. data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
  90. data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
  91. data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
  92. data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
  93. data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
  94. data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
  95. data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
  96. data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
  97. data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
  98. data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
  99. data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
  100. data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
  101. data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
  102. data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
  103. data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
  104. data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
  105. data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
  106. data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
  107. data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
  108. data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
  109. data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
  110. data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
  111. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
  112. data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
  113. data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
  114. data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
  115. data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
  116. data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
  117. data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
  118. data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
  119. data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
  120. data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
  121. data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
  122. data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
  123. data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
  124. data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
  125. data/vendor/faiss/faiss/impl/FaissException.h +13 -34
  126. data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
  127. data/vendor/faiss/faiss/impl/HNSW.h +52 -30
  128. data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
  129. data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
  130. data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
  131. data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
  132. data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
  133. data/vendor/faiss/faiss/impl/NSG.h +1 -1
  134. data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
  135. data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
  136. data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
  137. data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
  138. data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
  139. data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
  140. data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
  141. data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
  142. data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
  143. data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
  144. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
  145. data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
  146. data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
  147. data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
  148. data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
  149. data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
  150. data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
  151. data/vendor/faiss/faiss/impl/io.cpp +23 -15
  152. data/vendor/faiss/faiss/impl/io.h +4 -4
  153. data/vendor/faiss/faiss/impl/io_macros.h +6 -0
  154. data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
  155. data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
  156. data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
  157. data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
  158. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
  159. data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
  160. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
  161. data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
  162. data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
  163. data/vendor/faiss/faiss/index_factory.cpp +41 -20
  164. data/vendor/faiss/faiss/index_io.h +12 -5
  165. data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
  166. data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
  167. data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
  168. data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
  169. data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
  170. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
  171. data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
  172. data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
  173. data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
  174. data/vendor/faiss/faiss/utils/Heap.h +105 -0
  175. data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
  176. data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
  177. data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
  178. data/vendor/faiss/faiss/utils/bf16.h +36 -0
  179. data/vendor/faiss/faiss/utils/distances.cpp +147 -123
  180. data/vendor/faiss/faiss/utils/distances.h +86 -9
  181. data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
  182. data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
  183. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
  184. data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
  185. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
  186. data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
  187. data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
  188. data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
  189. data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
  190. data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
  191. data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
  192. data/vendor/faiss/faiss/utils/fp16.h +2 -0
  193. data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
  194. data/vendor/faiss/faiss/utils/hamming.h +58 -0
  195. data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
  196. data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
  197. data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
  198. data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
  199. data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
  200. data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
  201. data/vendor/faiss/faiss/utils/prefetch.h +77 -0
  202. data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
  203. data/vendor/faiss/faiss/utils/random.cpp +43 -0
  204. data/vendor/faiss/faiss/utils/random.h +25 -0
  205. data/vendor/faiss/faiss/utils/simdlib.h +10 -1
  206. data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
  207. data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
  208. data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
  209. data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
  210. data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
  211. data/vendor/faiss/faiss/utils/sorting.h +27 -0
  212. data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
  213. data/vendor/faiss/faiss/utils/utils.cpp +120 -7
  214. data/vendor/faiss/faiss/utils/utils.h +60 -20
  215. metadata +23 -4
  216. data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
@@ -0,0 +1,960 @@
1
+ /**
2
+ * Copyright (c) Facebook, Inc. and its affiliates.
3
+ *
4
+ * This source code is licensed under the MIT license found in the
5
+ * LICENSE file in the root directory of this source tree.
6
+ */
7
+
8
+ #include <faiss/impl/residual_quantizer_encode_steps.h>
9
+
10
+ #include <faiss/impl/AuxIndexStructures.h>
11
+ #include <faiss/impl/FaissAssert.h>
12
+ #include <faiss/impl/ResidualQuantizer.h>
13
+ #include <faiss/utils/Heap.h>
14
+ #include <faiss/utils/distances.h>
15
+ #include <faiss/utils/simdlib.h>
16
+ #include <faiss/utils/utils.h>
17
+
18
+ #include <faiss/utils/approx_topk/approx_topk.h>
19
+
20
+ extern "C" {
21
+
22
+ // general matrix multiplication
23
+ int sgemm_(
24
+ const char* transa,
25
+ const char* transb,
26
+ FINTEGER* m,
27
+ FINTEGER* n,
28
+ FINTEGER* k,
29
+ const float* alpha,
30
+ const float* a,
31
+ FINTEGER* lda,
32
+ const float* b,
33
+ FINTEGER* ldb,
34
+ float* beta,
35
+ float* c,
36
+ FINTEGER* ldc);
37
+ }
38
+
39
+ namespace faiss {
40
+
41
+ /********************************************************************
42
+ * Basic routines
43
+ ********************************************************************/
44
+
45
+ namespace {
46
+
47
+ template <size_t M, size_t NK>
48
+ void accum_and_store_tab(
49
+ const size_t m_offset,
50
+ const float* const __restrict codebook_cross_norms,
51
+ const uint64_t* const __restrict codebook_offsets,
52
+ const int32_t* const __restrict codes_i,
53
+ const size_t b,
54
+ const size_t ldc,
55
+ const size_t K,
56
+ float* const __restrict output) {
57
+ // load pointers into registers
58
+ const float* cbs[M];
59
+ for (size_t ij = 0; ij < M; ij++) {
60
+ const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
61
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
62
+ }
63
+
64
+ // do accumulation in registers using SIMD.
65
+ // It is possible that compiler may be smart enough so that
66
+ // this manual SIMD unrolling might be unneeded.
67
+ #if defined(__AVX2__) || defined(__aarch64__)
68
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
69
+
70
+ // process in chunks of size (8 * NK) floats
71
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
72
+ simd8float32 regs[NK];
73
+ for (size_t ik = 0; ik < NK; ik++) {
74
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
75
+ }
76
+
77
+ for (size_t ij = 1; ij < M; ij++) {
78
+ for (size_t ik = 0; ik < NK; ik++) {
79
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
80
+ }
81
+ }
82
+
83
+ // write the result
84
+ for (size_t ik = 0; ik < NK; ik++) {
85
+ regs[ik].storeu(output + kk + ik * 8);
86
+ }
87
+ }
88
+ #else
89
+ const size_t K8 = 0;
90
+ #endif
91
+
92
+ // process leftovers
93
+ for (size_t kk = K8; kk < K; kk++) {
94
+ float reg = cbs[0][kk];
95
+ for (size_t ij = 1; ij < M; ij++) {
96
+ reg += cbs[ij][kk];
97
+ }
98
+ output[kk] = reg;
99
+ }
100
+ }
101
+
102
+ template <size_t M, size_t NK>
103
+ void accum_and_add_tab(
104
+ const size_t m_offset,
105
+ const float* const __restrict codebook_cross_norms,
106
+ const uint64_t* const __restrict codebook_offsets,
107
+ const int32_t* const __restrict codes_i,
108
+ const size_t b,
109
+ const size_t ldc,
110
+ const size_t K,
111
+ float* const __restrict output) {
112
+ // load pointers into registers
113
+ const float* cbs[M];
114
+ for (size_t ij = 0; ij < M; ij++) {
115
+ const size_t code = static_cast<size_t>(codes_i[b * m_offset + ij]);
116
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
117
+ }
118
+
119
+ // do accumulation in registers using SIMD.
120
+ // It is possible that compiler may be smart enough so that
121
+ // this manual SIMD unrolling might be unneeded.
122
+ #if defined(__AVX2__) || defined(__aarch64__)
123
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
124
+
125
+ // process in chunks of size (8 * NK) floats
126
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
127
+ simd8float32 regs[NK];
128
+ for (size_t ik = 0; ik < NK; ik++) {
129
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
130
+ }
131
+
132
+ for (size_t ij = 1; ij < M; ij++) {
133
+ for (size_t ik = 0; ik < NK; ik++) {
134
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
135
+ }
136
+ }
137
+
138
+ // write the result
139
+ for (size_t ik = 0; ik < NK; ik++) {
140
+ simd8float32 existing(output + kk + ik * 8);
141
+ existing += regs[ik];
142
+ existing.storeu(output + kk + ik * 8);
143
+ }
144
+ }
145
+ #else
146
+ const size_t K8 = 0;
147
+ #endif
148
+
149
+ // process leftovers
150
+ for (size_t kk = K8; kk < K; kk++) {
151
+ float reg = cbs[0][kk];
152
+ for (size_t ij = 1; ij < M; ij++) {
153
+ reg += cbs[ij][kk];
154
+ }
155
+ output[kk] += reg;
156
+ }
157
+ }
158
+
159
+ template <size_t M, size_t NK>
160
+ void accum_and_finalize_tab(
161
+ const float* const __restrict codebook_cross_norms,
162
+ const uint64_t* const __restrict codebook_offsets,
163
+ const int32_t* const __restrict codes_i,
164
+ const size_t b,
165
+ const size_t ldc,
166
+ const size_t K,
167
+ const float* const __restrict distances_i,
168
+ const float* const __restrict cd_common,
169
+ float* const __restrict output) {
170
+ // load pointers into registers
171
+ const float* cbs[M];
172
+ for (size_t ij = 0; ij < M; ij++) {
173
+ const size_t code = static_cast<size_t>(codes_i[b * M + ij]);
174
+ cbs[ij] = &codebook_cross_norms[(codebook_offsets[ij] + code) * ldc];
175
+ }
176
+
177
+ // do accumulation in registers using SIMD.
178
+ // It is possible that compiler may be smart enough so that
179
+ // this manual SIMD unrolling might be unneeded.
180
+ #if defined(__AVX2__) || defined(__aarch64__)
181
+ const size_t K8 = (K / (8 * NK)) * (8 * NK);
182
+
183
+ // process in chunks of size (8 * NK) floats
184
+ for (size_t kk = 0; kk < K8; kk += 8 * NK) {
185
+ simd8float32 regs[NK];
186
+ for (size_t ik = 0; ik < NK; ik++) {
187
+ regs[ik].loadu(cbs[0] + kk + ik * 8);
188
+ }
189
+
190
+ for (size_t ij = 1; ij < M; ij++) {
191
+ for (size_t ik = 0; ik < NK; ik++) {
192
+ regs[ik] += simd8float32(cbs[ij] + kk + ik * 8);
193
+ }
194
+ }
195
+
196
+ simd8float32 two(2.0f);
197
+ for (size_t ik = 0; ik < NK; ik++) {
198
+ // cent_distances[b * K + k] = distances_i[b] + cd_common[k]
199
+ // + 2 * dp[k];
200
+
201
+ simd8float32 common_v(cd_common + kk + ik * 8);
202
+ common_v = fmadd(two, regs[ik], common_v);
203
+
204
+ common_v += simd8float32(distances_i[b]);
205
+ common_v.storeu(output + b * K + kk + ik * 8);
206
+ }
207
+ }
208
+ #else
209
+ const size_t K8 = 0;
210
+ #endif
211
+
212
+ // process leftovers
213
+ for (size_t kk = K8; kk < K; kk++) {
214
+ float reg = cbs[0][kk];
215
+ for (size_t ij = 1; ij < M; ij++) {
216
+ reg += cbs[ij][kk];
217
+ }
218
+
219
+ output[b * K + kk] = distances_i[b] + cd_common[kk] + 2 * reg;
220
+ }
221
+ }
222
+
223
+ } // anonymous namespace
224
+
225
+ /********************************************************************
226
+ * Single encoding step
227
+ ********************************************************************/
228
+
229
+ void beam_search_encode_step(
230
+ size_t d,
231
+ size_t K,
232
+ const float* cent, /// size (K, d)
233
+ size_t n,
234
+ size_t beam_size,
235
+ const float* residuals, /// size (n, beam_size, d)
236
+ size_t m,
237
+ const int32_t* codes, /// size (n, beam_size, m)
238
+ size_t new_beam_size,
239
+ int32_t* new_codes, /// size (n, new_beam_size, m + 1)
240
+ float* new_residuals, /// size (n, new_beam_size, d)
241
+ float* new_distances, /// size (n, new_beam_size)
242
+ Index* assign_index,
243
+ ApproxTopK_mode_t approx_topk_mode) {
244
+ // we have to fill in the whole output matrix
245
+ FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
246
+
247
+ std::vector<float> cent_distances;
248
+ std::vector<idx_t> cent_ids;
249
+
250
+ if (assign_index) {
251
+ // search beam_size distances per query
252
+ FAISS_THROW_IF_NOT(assign_index->d == d);
253
+ cent_distances.resize(n * beam_size * new_beam_size);
254
+ cent_ids.resize(n * beam_size * new_beam_size);
255
+ if (assign_index->ntotal != 0) {
256
+ // then we assume the codebooks are already added to the index
257
+ FAISS_THROW_IF_NOT(assign_index->ntotal == K);
258
+ } else {
259
+ assign_index->add(K, cent);
260
+ }
261
+
262
+ // printf("beam_search_encode_step -- mem usage %zd\n",
263
+ // get_mem_usage_kb());
264
+ assign_index->search(
265
+ n * beam_size,
266
+ residuals,
267
+ new_beam_size,
268
+ cent_distances.data(),
269
+ cent_ids.data());
270
+ } else {
271
+ // do one big distance computation
272
+ cent_distances.resize(n * beam_size * K);
273
+ pairwise_L2sqr(
274
+ d, n * beam_size, residuals, K, cent, cent_distances.data());
275
+ }
276
+ InterruptCallback::check();
277
+
278
+ #pragma omp parallel for if (n > 100)
279
+ for (int64_t i = 0; i < n; i++) {
280
+ const int32_t* codes_i = codes + i * m * beam_size;
281
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
282
+ const float* residuals_i = residuals + i * d * beam_size;
283
+ float* new_residuals_i = new_residuals + i * d * new_beam_size;
284
+
285
+ float* new_distances_i = new_distances + i * new_beam_size;
286
+ using C = CMax<float, int>;
287
+
288
+ if (assign_index) {
289
+ const float* cent_distances_i =
290
+ cent_distances.data() + i * beam_size * new_beam_size;
291
+ const idx_t* cent_ids_i =
292
+ cent_ids.data() + i * beam_size * new_beam_size;
293
+
294
+ // here we could be a tad more efficient by merging sorted arrays
295
+ for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
296
+ new_distances_i[i_2] = C::neutral();
297
+ }
298
+ std::vector<int> perm(new_beam_size, -1);
299
+ heap_addn<C>(
300
+ new_beam_size,
301
+ new_distances_i,
302
+ perm.data(),
303
+ cent_distances_i,
304
+ nullptr,
305
+ beam_size * new_beam_size);
306
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
307
+
308
+ for (int j = 0; j < new_beam_size; j++) {
309
+ int js = perm[j] / new_beam_size;
310
+ int ls = cent_ids_i[perm[j]];
311
+ if (m > 0) {
312
+ memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
313
+ }
314
+ new_codes_i[m] = ls;
315
+ new_codes_i += m + 1;
316
+ fvec_sub(
317
+ d,
318
+ residuals_i + js * d,
319
+ cent + ls * d,
320
+ new_residuals_i);
321
+ new_residuals_i += d;
322
+ }
323
+
324
+ } else {
325
+ const float* cent_distances_i =
326
+ cent_distances.data() + i * beam_size * K;
327
+ // then we have to select the best results
328
+ for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
329
+ new_distances_i[i_2] = C::neutral();
330
+ }
331
+ std::vector<int> perm(new_beam_size, -1);
332
+
333
+ #define HANDLE_APPROX(NB, BD) \
334
+ case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
335
+ HeapWithBuckets<C, NB, BD>::bs_addn( \
336
+ beam_size, \
337
+ K, \
338
+ cent_distances_i, \
339
+ new_beam_size, \
340
+ new_distances_i, \
341
+ perm.data()); \
342
+ break;
343
+
344
+ switch (approx_topk_mode) {
345
+ HANDLE_APPROX(8, 3)
346
+ HANDLE_APPROX(8, 2)
347
+ HANDLE_APPROX(16, 2)
348
+ HANDLE_APPROX(32, 2)
349
+ default:
350
+ heap_addn<C>(
351
+ new_beam_size,
352
+ new_distances_i,
353
+ perm.data(),
354
+ cent_distances_i,
355
+ nullptr,
356
+ beam_size * K);
357
+ }
358
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
359
+
360
+ #undef HANDLE_APPROX
361
+
362
+ for (int j = 0; j < new_beam_size; j++) {
363
+ int js = perm[j] / K;
364
+ int ls = perm[j] % K;
365
+ if (m > 0) {
366
+ memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
367
+ }
368
+ new_codes_i[m] = ls;
369
+ new_codes_i += m + 1;
370
+ fvec_sub(
371
+ d,
372
+ residuals_i + js * d,
373
+ cent + ls * d,
374
+ new_residuals_i);
375
+ new_residuals_i += d;
376
+ }
377
+ }
378
+ }
379
+ }
380
+
381
+ // exposed in the faiss namespace
382
+ void beam_search_encode_step_tab(
383
+ size_t K,
384
+ size_t n,
385
+ size_t beam_size, // input sizes
386
+ const float* codebook_cross_norms, // size K * ldc
387
+ size_t ldc,
388
+ const uint64_t* codebook_offsets, // m
389
+ const float* query_cp, // size n * ldqc
390
+ size_t ldqc, // >= K
391
+ const float* cent_norms_i, // size K
392
+ size_t m,
393
+ const int32_t* codes, // n * beam_size * m
394
+ const float* distances, // n * beam_size
395
+ size_t new_beam_size,
396
+ int32_t* new_codes, // n * new_beam_size * (m + 1)
397
+ float* new_distances, // n * new_beam_size
398
+ ApproxTopK_mode_t approx_topk_mode) //
399
+ {
400
+ FAISS_THROW_IF_NOT(ldc >= K);
401
+
402
+ #pragma omp parallel for if (n > 100) schedule(dynamic)
403
+ for (int64_t i = 0; i < n; i++) {
404
+ std::vector<float> cent_distances(beam_size * K);
405
+ std::vector<float> cd_common(K);
406
+
407
+ const int32_t* codes_i = codes + i * m * beam_size;
408
+ const float* query_cp_i = query_cp + i * ldqc;
409
+ const float* distances_i = distances + i * beam_size;
410
+
411
+ for (size_t k = 0; k < K; k++) {
412
+ cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
413
+ }
414
+
415
+ bool use_baseline_implementation = false;
416
+
417
+ // This is the baseline implementation. Its primary flaw
418
+ // that it writes way too many info to the temporary buffer
419
+ // called dp.
420
+ //
421
+ // This baseline code is kept intentionally because it is easy to
422
+ // understand what an optimized version optimizes exactly.
423
+ //
424
+ if (use_baseline_implementation) {
425
+ for (size_t b = 0; b < beam_size; b++) {
426
+ std::vector<float> dp(K);
427
+
428
+ for (size_t m1 = 0; m1 < m; m1++) {
429
+ size_t c = codes_i[b * m + m1];
430
+ const float* cb =
431
+ &codebook_cross_norms
432
+ [(codebook_offsets[m1] + c) * ldc];
433
+ fvec_add(K, cb, dp.data(), dp.data());
434
+ }
435
+
436
+ for (size_t k = 0; k < K; k++) {
437
+ cent_distances[b * K + k] =
438
+ distances_i[b] + cd_common[k] + 2 * dp[k];
439
+ }
440
+ }
441
+
442
+ } else {
443
+ // An optimized implementation that avoids using a temporary buffer
444
+ // and does the accumulation in registers.
445
+
446
+ // Compute a sum of NK AQ codes.
447
+ #define ACCUM_AND_FINALIZE_TAB(NK) \
448
+ case NK: \
449
+ for (size_t b = 0; b < beam_size; b++) { \
450
+ accum_and_finalize_tab<NK, 4>( \
451
+ codebook_cross_norms, \
452
+ codebook_offsets, \
453
+ codes_i, \
454
+ b, \
455
+ ldc, \
456
+ K, \
457
+ distances_i, \
458
+ cd_common.data(), \
459
+ cent_distances.data()); \
460
+ } \
461
+ break;
462
+
463
+ // this version contains many switch-case scenarios, but
464
+ // they won't affect branch predictor.
465
+ switch (m) {
466
+ case 0:
467
+ // trivial case
468
+ for (size_t b = 0; b < beam_size; b++) {
469
+ for (size_t k = 0; k < K; k++) {
470
+ cent_distances[b * K + k] =
471
+ distances_i[b] + cd_common[k];
472
+ }
473
+ }
474
+ break;
475
+
476
+ ACCUM_AND_FINALIZE_TAB(1)
477
+ ACCUM_AND_FINALIZE_TAB(2)
478
+ ACCUM_AND_FINALIZE_TAB(3)
479
+ ACCUM_AND_FINALIZE_TAB(4)
480
+ ACCUM_AND_FINALIZE_TAB(5)
481
+ ACCUM_AND_FINALIZE_TAB(6)
482
+ ACCUM_AND_FINALIZE_TAB(7)
483
+
484
+ default: {
485
+ // m >= 8 case.
486
+
487
+ // A temporary buffer has to be used due to the lack of
488
+ // registers. But we'll try to accumulate up to 8 AQ codes
489
+ // in registers and issue a single write operation to the
490
+ // buffer, while the baseline does no accumulation. So, the
491
+ // number of write operations to the temporary buffer is
492
+ // reduced 8x.
493
+
494
+ // allocate a temporary buffer
495
+ std::vector<float> dp(K);
496
+
497
+ for (size_t b = 0; b < beam_size; b++) {
498
+ // Initialize it. Compute a sum of first 8 AQ codes
499
+ // because m >= 8 .
500
+ accum_and_store_tab<8, 4>(
501
+ m,
502
+ codebook_cross_norms,
503
+ codebook_offsets,
504
+ codes_i,
505
+ b,
506
+ ldc,
507
+ K,
508
+ dp.data());
509
+
510
+ #define ACCUM_AND_ADD_TAB(NK) \
511
+ case NK: \
512
+ accum_and_add_tab<NK, 4>( \
513
+ m, \
514
+ codebook_cross_norms, \
515
+ codebook_offsets + im, \
516
+ codes_i + im, \
517
+ b, \
518
+ ldc, \
519
+ K, \
520
+ dp.data()); \
521
+ break;
522
+
523
+ // accumulate up to 8 additional AQ codes into
524
+ // a temporary buffer
525
+ for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
526
+ size_t m_left = m - im;
527
+ if (m_left > 8) {
528
+ m_left = 8;
529
+ }
530
+
531
+ switch (m_left) {
532
+ ACCUM_AND_ADD_TAB(1)
533
+ ACCUM_AND_ADD_TAB(2)
534
+ ACCUM_AND_ADD_TAB(3)
535
+ ACCUM_AND_ADD_TAB(4)
536
+ ACCUM_AND_ADD_TAB(5)
537
+ ACCUM_AND_ADD_TAB(6)
538
+ ACCUM_AND_ADD_TAB(7)
539
+ ACCUM_AND_ADD_TAB(8)
540
+ }
541
+ }
542
+
543
+ // done. finalize the result
544
+ for (size_t k = 0; k < K; k++) {
545
+ cent_distances[b * K + k] =
546
+ distances_i[b] + cd_common[k] + 2 * dp[k];
547
+ }
548
+ }
549
+ }
550
+ }
551
+
552
+ // the optimized implementation ends here
553
+ }
554
+ using C = CMax<float, int>;
555
+ int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
556
+ float* new_distances_i = new_distances + i * new_beam_size;
557
+
558
+ const float* cent_distances_i = cent_distances.data();
559
+
560
+ // then we have to select the best results
561
+ for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
562
+ new_distances_i[i_2] = C::neutral();
563
+ }
564
+ std::vector<int> perm(new_beam_size, -1);
565
+
566
+ #define HANDLE_APPROX(NB, BD) \
567
+ case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
568
+ HeapWithBuckets<C, NB, BD>::bs_addn( \
569
+ beam_size, \
570
+ K, \
571
+ cent_distances_i, \
572
+ new_beam_size, \
573
+ new_distances_i, \
574
+ perm.data()); \
575
+ break;
576
+
577
+ switch (approx_topk_mode) {
578
+ HANDLE_APPROX(8, 3)
579
+ HANDLE_APPROX(8, 2)
580
+ HANDLE_APPROX(16, 2)
581
+ HANDLE_APPROX(32, 2)
582
+ default:
583
+ heap_addn<C>(
584
+ new_beam_size,
585
+ new_distances_i,
586
+ perm.data(),
587
+ cent_distances_i,
588
+ nullptr,
589
+ beam_size * K);
590
+ break;
591
+ }
592
+
593
+ heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
594
+
595
+ #undef HANDLE_APPROX
596
+
597
+ for (int j = 0; j < new_beam_size; j++) {
598
+ int js = perm[j] / K;
599
+ int ls = perm[j] % K;
600
+ if (m > 0) {
601
+ memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
602
+ }
603
+ new_codes_i[m] = ls;
604
+ new_codes_i += m + 1;
605
+ }
606
+ }
607
+ }
608
+
609
+ /********************************************************************
610
+ * Multiple encoding steps
611
+ ********************************************************************/
612
+
613
+ namespace rq_encode_steps {
614
+
615
+ void refine_beam_mp(
616
+ const ResidualQuantizer& rq,
617
+ size_t n,
618
+ size_t beam_size,
619
+ const float* x,
620
+ int out_beam_size,
621
+ int32_t* out_codes,
622
+ float* out_residuals,
623
+ float* out_distances,
624
+ RefineBeamMemoryPool& pool) {
625
+ int cur_beam_size = beam_size;
626
+
627
+ double t0 = getmillisecs();
628
+
629
+ // find the max_beam_size
630
+ int max_beam_size = 0;
631
+ {
632
+ int tmp_beam_size = cur_beam_size;
633
+ for (int m = 0; m < rq.M; m++) {
634
+ int K = 1 << rq.nbits[m];
635
+ int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
636
+ tmp_beam_size = new_beam_size;
637
+
638
+ if (max_beam_size < new_beam_size) {
639
+ max_beam_size = new_beam_size;
640
+ }
641
+ }
642
+ }
643
+
644
+ // preallocate buffers
645
+ pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
646
+ pool.new_residuals.resize(n * max_beam_size * rq.d);
647
+
648
+ pool.codes.resize(n * max_beam_size * (rq.M + 1));
649
+ pool.distances.resize(n * max_beam_size);
650
+ pool.residuals.resize(n * rq.d * max_beam_size);
651
+
652
+ for (size_t i = 0; i < n * rq.d * beam_size; i++) {
653
+ pool.residuals[i] = x[i];
654
+ }
655
+
656
+ // set up pointers to buffers
657
+ int32_t* __restrict codes_ptr = pool.codes.data();
658
+ float* __restrict residuals_ptr = pool.residuals.data();
659
+
660
+ int32_t* __restrict new_codes_ptr = pool.new_codes.data();
661
+ float* __restrict new_residuals_ptr = pool.new_residuals.data();
662
+
663
+ // index
664
+ std::unique_ptr<Index> assign_index;
665
+ if (rq.assign_index_factory) {
666
+ assign_index.reset((*rq.assign_index_factory)(rq.d));
667
+ }
668
+
669
+ // main loop
670
+ size_t codes_size = 0;
671
+ size_t distances_size = 0;
672
+ size_t residuals_size = 0;
673
+
674
+ for (int m = 0; m < rq.M; m++) {
675
+ int K = 1 << rq.nbits[m];
676
+
677
+ const float* __restrict codebooks_m =
678
+ rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
679
+
680
+ const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
681
+
682
+ codes_size = n * new_beam_size * (m + 1);
683
+ residuals_size = n * new_beam_size * rq.d;
684
+ distances_size = n * new_beam_size;
685
+
686
+ beam_search_encode_step(
687
+ rq.d,
688
+ K,
689
+ codebooks_m,
690
+ n,
691
+ cur_beam_size,
692
+ residuals_ptr,
693
+ m,
694
+ codes_ptr,
695
+ new_beam_size,
696
+ new_codes_ptr,
697
+ new_residuals_ptr,
698
+ pool.distances.data(),
699
+ assign_index.get(),
700
+ rq.approx_topk_mode);
701
+
702
+ if (assign_index != nullptr) {
703
+ assign_index->reset();
704
+ }
705
+
706
+ std::swap(codes_ptr, new_codes_ptr);
707
+ std::swap(residuals_ptr, new_residuals_ptr);
708
+
709
+ cur_beam_size = new_beam_size;
710
+
711
+ if (rq.verbose) {
712
+ float sum_distances = 0;
713
+ for (int j = 0; j < distances_size; j++) {
714
+ sum_distances += pool.distances[j];
715
+ }
716
+
717
+ printf("[%.3f s] encode stage %d, %d bits, "
718
+ "total error %g, beam_size %d\n",
719
+ (getmillisecs() - t0) / 1000,
720
+ m,
721
+ int(rq.nbits[m]),
722
+ sum_distances,
723
+ cur_beam_size);
724
+ }
725
+ }
726
+
727
+ if (out_codes) {
728
+ memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
729
+ }
730
+ if (out_residuals) {
731
+ memcpy(out_residuals,
732
+ residuals_ptr,
733
+ residuals_size * sizeof(*residuals_ptr));
734
+ }
735
+ if (out_distances) {
736
+ memcpy(out_distances,
737
+ pool.distances.data(),
738
+ distances_size * sizeof(pool.distances[0]));
739
+ }
740
+ }
741
+
742
+ void refine_beam_LUT_mp(
743
+ const ResidualQuantizer& rq,
744
+ size_t n,
745
+ const float* query_norms, // size n
746
+ const float* query_cp, //
747
+ int out_beam_size,
748
+ int32_t* out_codes,
749
+ float* out_distances,
750
+ RefineBeamLUTMemoryPool& pool) {
751
+ int beam_size = 1;
752
+
753
+ double t0 = getmillisecs();
754
+
755
+ // find the max_beam_size
756
+ int max_beam_size = 0;
757
+ {
758
+ int tmp_beam_size = beam_size;
759
+ for (int m = 0; m < rq.M; m++) {
760
+ int K = 1 << rq.nbits[m];
761
+ int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
762
+ tmp_beam_size = new_beam_size;
763
+
764
+ if (max_beam_size < new_beam_size) {
765
+ max_beam_size = new_beam_size;
766
+ }
767
+ }
768
+ }
769
+
770
+ // preallocate buffers
771
+ pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
772
+ pool.new_distances.resize(n * max_beam_size);
773
+
774
+ pool.codes.resize(n * max_beam_size * (rq.M + 1));
775
+ pool.distances.resize(n * max_beam_size);
776
+
777
+ for (size_t i = 0; i < n; i++) {
778
+ pool.distances[i] = query_norms[i];
779
+ }
780
+
781
+ // set up pointers to buffers
782
+ int32_t* __restrict new_codes_ptr = pool.new_codes.data();
783
+ float* __restrict new_distances_ptr = pool.new_distances.data();
784
+
785
+ int32_t* __restrict codes_ptr = pool.codes.data();
786
+ float* __restrict distances_ptr = pool.distances.data();
787
+
788
+ // main loop
789
+ size_t codes_size = 0;
790
+ size_t distances_size = 0;
791
+ size_t cross_ofs = 0;
792
+ for (int m = 0; m < rq.M; m++) {
793
+ int K = 1 << rq.nbits[m];
794
+
795
+ // it is guaranteed that (new_beam_size <= max_beam_size)
796
+ int new_beam_size = std::min(beam_size * K, out_beam_size);
797
+
798
+ codes_size = n * new_beam_size * (m + 1);
799
+ distances_size = n * new_beam_size;
800
+ FAISS_THROW_IF_NOT(
801
+ cross_ofs + rq.codebook_offsets[m] * K <=
802
+ rq.codebook_cross_products.size());
803
+ beam_search_encode_step_tab(
804
+ K,
805
+ n,
806
+ beam_size,
807
+ rq.codebook_cross_products.data() + cross_ofs,
808
+ K,
809
+ rq.codebook_offsets.data(),
810
+ query_cp + rq.codebook_offsets[m],
811
+ rq.total_codebook_size,
812
+ rq.cent_norms.data() + rq.codebook_offsets[m],
813
+ m,
814
+ codes_ptr,
815
+ distances_ptr,
816
+ new_beam_size,
817
+ new_codes_ptr,
818
+ new_distances_ptr,
819
+ rq.approx_topk_mode);
820
+ cross_ofs += rq.codebook_offsets[m] * K;
821
+ std::swap(codes_ptr, new_codes_ptr);
822
+ std::swap(distances_ptr, new_distances_ptr);
823
+
824
+ beam_size = new_beam_size;
825
+
826
+ if (rq.verbose) {
827
+ float sum_distances = 0;
828
+ for (int j = 0; j < distances_size; j++) {
829
+ sum_distances += distances_ptr[j];
830
+ }
831
+ printf("[%.3f s] encode stage %d, %d bits, "
832
+ "total error %g, beam_size %d\n",
833
+ (getmillisecs() - t0) / 1000,
834
+ m,
835
+ int(rq.nbits[m]),
836
+ sum_distances,
837
+ beam_size);
838
+ }
839
+ }
840
+ if (out_codes) {
841
+ memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
842
+ }
843
+ if (out_distances) {
844
+ memcpy(out_distances,
845
+ distances_ptr,
846
+ distances_size * sizeof(*distances_ptr));
847
+ }
848
+ }
849
+
850
+ // this is for use_beam_LUT == 0
851
+ void compute_codes_add_centroids_mp_lut0(
852
+ const ResidualQuantizer& rq,
853
+ const float* x,
854
+ uint8_t* codes_out,
855
+ size_t n,
856
+ const float* centroids,
857
+ ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
858
+ pool.codes.resize(rq.max_beam_size * rq.M * n);
859
+ pool.distances.resize(rq.max_beam_size * n);
860
+
861
+ pool.residuals.resize(rq.max_beam_size * n * rq.d);
862
+
863
+ refine_beam_mp(
864
+ rq,
865
+ n,
866
+ 1,
867
+ x,
868
+ rq.max_beam_size,
869
+ pool.codes.data(),
870
+ pool.residuals.data(),
871
+ pool.distances.data(),
872
+ pool.refine_beam_pool);
873
+
874
+ if (rq.search_type == ResidualQuantizer::ST_norm_float ||
875
+ rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
876
+ rq.search_type == ResidualQuantizer::ST_norm_qint4) {
877
+ pool.norms.resize(n);
878
+ // recover the norms of reconstruction as
879
+ // || original_vector - residual ||^2
880
+ for (size_t i = 0; i < n; i++) {
881
+ pool.norms[i] = fvec_L2sqr(
882
+ x + i * rq.d,
883
+ pool.residuals.data() + i * rq.max_beam_size * rq.d,
884
+ rq.d);
885
+ }
886
+ }
887
+
888
+ // pack only the first code of the beam
889
+ // (hence the ld_codes=M * max_beam_size)
890
+ rq.pack_codes(
891
+ n,
892
+ pool.codes.data(),
893
+ codes_out,
894
+ rq.M * rq.max_beam_size,
895
+ (pool.norms.size() > 0) ? pool.norms.data() : nullptr,
896
+ centroids);
897
+ }
898
+
899
+ // use_beam_LUT == 1
900
+ void compute_codes_add_centroids_mp_lut1(
901
+ const ResidualQuantizer& rq,
902
+ const float* x,
903
+ uint8_t* codes_out,
904
+ size_t n,
905
+ const float* centroids,
906
+ ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
907
+ //
908
+ pool.codes.resize(rq.max_beam_size * rq.M * n);
909
+ pool.distances.resize(rq.max_beam_size * n);
910
+
911
+ FAISS_THROW_IF_NOT_MSG(
912
+ rq.M == 1 || rq.codebook_cross_products.size() > 0,
913
+ "call compute_codebook_tables first");
914
+
915
+ pool.query_norms.resize(n);
916
+ fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
917
+
918
+ pool.query_cp.resize(n * rq.total_codebook_size);
919
+ {
920
+ FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
921
+ float zero = 0, one = 1;
922
+ sgemm_("Transposed",
923
+ "Not transposed",
924
+ &ti,
925
+ &ni,
926
+ &di,
927
+ &one,
928
+ rq.codebooks.data(),
929
+ &di,
930
+ x,
931
+ &di,
932
+ &zero,
933
+ pool.query_cp.data(),
934
+ &ti);
935
+ }
936
+
937
+ refine_beam_LUT_mp(
938
+ rq,
939
+ n,
940
+ pool.query_norms.data(),
941
+ pool.query_cp.data(),
942
+ rq.max_beam_size,
943
+ pool.codes.data(),
944
+ pool.distances.data(),
945
+ pool.refine_beam_lut_pool);
946
+
947
+ // pack only the first code of the beam
948
+ // (hence the ld_codes=M * max_beam_size)
949
+ rq.pack_codes(
950
+ n,
951
+ pool.codes.data(),
952
+ codes_out,
953
+ rq.M * rq.max_beam_size,
954
+ nullptr,
955
+ centroids);
956
+ }
957
+
958
+ } // namespace rq_encode_steps
959
+
960
+ } // namespace faiss