faiss 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +1 -1
- data/ext/faiss/extconf.rb +9 -2
- data/ext/faiss/index.cpp +1 -1
- data/ext/faiss/index_binary.cpp +2 -2
- data/ext/faiss/product_quantizer.cpp +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +7 -7
- data/vendor/faiss/faiss/AutoTune.h +1 -2
- data/vendor/faiss/faiss/Clustering.cpp +39 -22
- data/vendor/faiss/faiss/Clustering.h +40 -21
- data/vendor/faiss/faiss/IVFlib.cpp +26 -12
- data/vendor/faiss/faiss/Index.cpp +1 -1
- data/vendor/faiss/faiss/Index.h +40 -10
- data/vendor/faiss/faiss/Index2Layer.cpp +7 -7
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +176 -166
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +15 -15
- data/vendor/faiss/faiss/IndexBinary.cpp +9 -4
- data/vendor/faiss/faiss/IndexBinary.h +8 -19
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +24 -31
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +1 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +25 -50
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +107 -188
- data/vendor/faiss/faiss/IndexFastScan.cpp +95 -146
- data/vendor/faiss/faiss/IndexFastScan.h +9 -8
- data/vendor/faiss/faiss/IndexFlat.cpp +206 -10
- data/vendor/faiss/faiss/IndexFlat.h +20 -1
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +170 -5
- data/vendor/faiss/faiss/IndexFlatCodes.h +23 -4
- data/vendor/faiss/faiss/IndexHNSW.cpp +231 -382
- data/vendor/faiss/faiss/IndexHNSW.h +62 -49
- data/vendor/faiss/faiss/IndexIDMap.cpp +69 -28
- data/vendor/faiss/faiss/IndexIDMap.h +24 -2
- data/vendor/faiss/faiss/IndexIVF.cpp +162 -56
- data/vendor/faiss/faiss/IndexIVF.h +46 -6
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +33 -26
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +6 -2
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +19 -46
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +502 -401
- data/vendor/faiss/faiss/IndexIVFFastScan.h +63 -26
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +15 -5
- data/vendor/faiss/faiss/IndexIVFFlat.h +3 -2
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.cpp +172 -0
- data/vendor/faiss/faiss/IndexIVFIndependentQuantizer.h +56 -0
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +79 -125
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -7
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +39 -52
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +4 -3
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +45 -29
- data/vendor/faiss/faiss/IndexIVFPQR.h +5 -2
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +25 -27
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +6 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +14 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +1 -19
- data/vendor/faiss/faiss/IndexLattice.h +3 -22
- data/vendor/faiss/faiss/IndexNNDescent.cpp +3 -33
- data/vendor/faiss/faiss/IndexNNDescent.h +1 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +11 -27
- data/vendor/faiss/faiss/IndexNSG.h +11 -11
- data/vendor/faiss/faiss/IndexNeuralNetCodec.cpp +56 -0
- data/vendor/faiss/faiss/IndexNeuralNetCodec.h +49 -0
- data/vendor/faiss/faiss/IndexPQ.cpp +72 -88
- data/vendor/faiss/faiss/IndexPQ.h +1 -4
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +1 -1
- data/vendor/faiss/faiss/IndexPreTransform.cpp +25 -31
- data/vendor/faiss/faiss/IndexPreTransform.h +1 -1
- data/vendor/faiss/faiss/IndexRefine.cpp +54 -24
- data/vendor/faiss/faiss/IndexRefine.h +7 -0
- data/vendor/faiss/faiss/IndexReplicas.cpp +23 -26
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +25 -17
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +6 -4
- data/vendor/faiss/faiss/IndexShards.cpp +21 -29
- data/vendor/faiss/faiss/IndexShardsIVF.cpp +1 -2
- data/vendor/faiss/faiss/MatrixStats.cpp +17 -32
- data/vendor/faiss/faiss/MatrixStats.h +21 -9
- data/vendor/faiss/faiss/MetaIndexes.cpp +35 -35
- data/vendor/faiss/faiss/MetricType.h +7 -2
- data/vendor/faiss/faiss/VectorTransform.cpp +13 -26
- data/vendor/faiss/faiss/VectorTransform.h +7 -7
- data/vendor/faiss/faiss/clone_index.cpp +15 -10
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +95 -17
- data/vendor/faiss/faiss/cppcontrib/factory_tools.cpp +152 -0
- data/vendor/faiss/faiss/cppcontrib/factory_tools.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +83 -30
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +123 -8
- data/vendor/faiss/faiss/gpu/GpuCloner.h +22 -0
- data/vendor/faiss/faiss/gpu/GpuClonerOptions.h +13 -0
- data/vendor/faiss/faiss/gpu/GpuDistance.h +46 -38
- data/vendor/faiss/faiss/gpu/GpuFaissAssert.h +1 -1
- data/vendor/faiss/faiss/gpu/GpuIndex.h +30 -12
- data/vendor/faiss/faiss/gpu/GpuIndexCagra.h +282 -0
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +4 -4
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +14 -9
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +20 -3
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -11
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +1 -3
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +24 -3
- data/vendor/faiss/faiss/gpu/GpuResources.h +39 -11
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +142 -17
- data/vendor/faiss/faiss/gpu/StandardGpuResources.h +57 -3
- data/vendor/faiss/faiss/gpu/impl/InterleavedCodes.cpp +26 -21
- data/vendor/faiss/faiss/gpu/perf/PerfClustering.cpp +7 -1
- data/vendor/faiss/faiss/gpu/test/TestCodePacking.cpp +8 -5
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +25 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +129 -9
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +332 -40
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexIVFPQ.cpp +299 -208
- data/vendor/faiss/faiss/gpu/test/TestGpuMemoryException.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +1 -1
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +6 -0
- data/vendor/faiss/faiss/gpu/utils/RaftUtils.h +75 -0
- data/vendor/faiss/faiss/gpu/utils/Timer.cpp +4 -1
- data/vendor/faiss/faiss/gpu/utils/Timer.h +1 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +3 -1
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +5 -5
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +26 -1
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +10 -3
- data/vendor/faiss/faiss/impl/DistanceComputer.h +70 -1
- data/vendor/faiss/faiss/impl/FaissAssert.h +4 -2
- data/vendor/faiss/faiss/impl/FaissException.h +13 -34
- data/vendor/faiss/faiss/impl/HNSW.cpp +605 -186
- data/vendor/faiss/faiss/impl/HNSW.h +52 -30
- data/vendor/faiss/faiss/impl/IDSelector.h +4 -4
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +11 -9
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +34 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +42 -27
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -29
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/PolysemousTraining.cpp +14 -12
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +25 -22
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +6 -2
- data/vendor/faiss/faiss/impl/Quantizer.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +27 -1015
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +5 -63
- data/vendor/faiss/faiss/impl/ResultHandler.h +347 -172
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +1104 -147
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +3 -8
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx2.h +285 -42
- data/vendor/faiss/faiss/impl/code_distance/code_distance-avx512.h +248 -0
- data/vendor/faiss/faiss/impl/code_distance/code_distance-generic.h +21 -14
- data/vendor/faiss/faiss/impl/code_distance/code_distance.h +22 -12
- data/vendor/faiss/faiss/impl/index_read.cpp +74 -34
- data/vendor/faiss/faiss/impl/index_read_utils.h +37 -0
- data/vendor/faiss/faiss/impl/index_write.cpp +88 -51
- data/vendor/faiss/faiss/impl/io.cpp +23 -15
- data/vendor/faiss/faiss/impl/io.h +4 -4
- data/vendor/faiss/faiss/impl/io_macros.h +6 -0
- data/vendor/faiss/faiss/impl/lattice_Zn.cpp +1 -1
- data/vendor/faiss/faiss/impl/platform_macros.h +40 -1
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +14 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +7 -6
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +52 -38
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +487 -49
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.cpp +960 -0
- data/vendor/faiss/faiss/impl/residual_quantizer_encode_steps.h +176 -0
- data/vendor/faiss/faiss/impl/simd_result_handlers.h +481 -225
- data/vendor/faiss/faiss/index_factory.cpp +41 -20
- data/vendor/faiss/faiss/index_io.h +12 -5
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.cpp +28 -8
- data/vendor/faiss/faiss/invlists/BlockInvertedLists.h +3 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +10 -2
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +73 -17
- data/vendor/faiss/faiss/invlists/InvertedLists.h +26 -8
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +24 -9
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.h +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +4 -4
- data/vendor/faiss/faiss/utils/Heap.cpp +3 -1
- data/vendor/faiss/faiss/utils/Heap.h +105 -0
- data/vendor/faiss/faiss/utils/NeuralNet.cpp +342 -0
- data/vendor/faiss/faiss/utils/NeuralNet.h +147 -0
- data/vendor/faiss/faiss/utils/WorkerThread.h +1 -0
- data/vendor/faiss/faiss/utils/bf16.h +36 -0
- data/vendor/faiss/faiss/utils/distances.cpp +147 -123
- data/vendor/faiss/faiss/utils/distances.h +86 -9
- data/vendor/faiss/faiss/utils/distances_fused/avx512.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/avx512.h +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.cpp +2 -2
- data/vendor/faiss/faiss/utils/distances_fused/distances_fused.h +1 -1
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.cpp +5 -5
- data/vendor/faiss/faiss/utils/distances_fused/simdlib_based.h +1 -1
- data/vendor/faiss/faiss/utils/distances_simd.cpp +1589 -243
- data/vendor/faiss/faiss/utils/extra_distances-inl.h +70 -0
- data/vendor/faiss/faiss/utils/extra_distances.cpp +85 -137
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -2
- data/vendor/faiss/faiss/utils/fp16-arm.h +29 -0
- data/vendor/faiss/faiss/utils/fp16.h +2 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +163 -111
- data/vendor/faiss/faiss/utils/hamming.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/avx2-inl.h +16 -89
- data/vendor/faiss/faiss/utils/hamming_distance/common.h +1 -0
- data/vendor/faiss/faiss/utils/hamming_distance/generic-inl.h +19 -88
- data/vendor/faiss/faiss/utils/hamming_distance/hamdis-inl.h +58 -0
- data/vendor/faiss/faiss/utils/hamming_distance/neon-inl.h +14 -104
- data/vendor/faiss/faiss/utils/partitioning.cpp +3 -4
- data/vendor/faiss/faiss/utils/prefetch.h +77 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +0 -14
- data/vendor/faiss/faiss/utils/random.cpp +43 -0
- data/vendor/faiss/faiss/utils/random.h +25 -0
- data/vendor/faiss/faiss/utils/simdlib.h +10 -1
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +0 -6
- data/vendor/faiss/faiss/utils/simdlib_avx512.h +296 -0
- data/vendor/faiss/faiss/utils/simdlib_neon.h +77 -79
- data/vendor/faiss/faiss/utils/simdlib_ppc64.h +1084 -0
- data/vendor/faiss/faiss/utils/sorting.cpp +140 -5
- data/vendor/faiss/faiss/utils/sorting.h +27 -0
- data/vendor/faiss/faiss/utils/transpose/transpose-avx512-inl.h +176 -0
- data/vendor/faiss/faiss/utils/utils.cpp +120 -7
- data/vendor/faiss/faiss/utils/utils.h +60 -20
- metadata +23 -4
- data/vendor/faiss/faiss/impl/code_distance/code_distance_avx512.h +0 -102
|
@@ -0,0 +1,960 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Copyright (c) Facebook, Inc. and its affiliates.
|
|
3
|
+
*
|
|
4
|
+
* This source code is licensed under the MIT license found in the
|
|
5
|
+
* LICENSE file in the root directory of this source tree.
|
|
6
|
+
*/
|
|
7
|
+
|
|
8
|
+
#include <faiss/impl/residual_quantizer_encode_steps.h>
|
|
9
|
+
|
|
10
|
+
#include <faiss/impl/AuxIndexStructures.h>
|
|
11
|
+
#include <faiss/impl/FaissAssert.h>
|
|
12
|
+
#include <faiss/impl/ResidualQuantizer.h>
|
|
13
|
+
#include <faiss/utils/Heap.h>
|
|
14
|
+
#include <faiss/utils/distances.h>
|
|
15
|
+
#include <faiss/utils/simdlib.h>
|
|
16
|
+
#include <faiss/utils/utils.h>
|
|
17
|
+
|
|
18
|
+
#include <faiss/utils/approx_topk/approx_topk.h>
|
|
19
|
+
|
|
20
|
+
extern "C" {
|
|
21
|
+
|
|
22
|
+
// general matrix multiplication
|
|
23
|
+
int sgemm_(
|
|
24
|
+
const char* transa,
|
|
25
|
+
const char* transb,
|
|
26
|
+
FINTEGER* m,
|
|
27
|
+
FINTEGER* n,
|
|
28
|
+
FINTEGER* k,
|
|
29
|
+
const float* alpha,
|
|
30
|
+
const float* a,
|
|
31
|
+
FINTEGER* lda,
|
|
32
|
+
const float* b,
|
|
33
|
+
FINTEGER* ldb,
|
|
34
|
+
float* beta,
|
|
35
|
+
float* c,
|
|
36
|
+
FINTEGER* ldc);
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
namespace faiss {
|
|
40
|
+
|
|
41
|
+
/********************************************************************
|
|
42
|
+
* Basic routines
|
|
43
|
+
********************************************************************/
|
|
44
|
+
|
|
45
|
+
namespace {

/// Points rows[j] (for j < M) at the cross-norm row selected by the j-th code
/// of beam entry `b`. That code is codes_i[b * code_stride + j], and its row
/// starts at codebook_cross_norms[(codebook_offsets[j] + code) * ldc].
template <size_t M>
inline void gather_cross_norm_rows(
        const size_t code_stride,
        const float* const __restrict codebook_cross_norms,
        const uint64_t* const __restrict codebook_offsets,
        const int32_t* const __restrict codes_i,
        const size_t b,
        const size_t ldc,
        const float** rows) {
    const int32_t* const entry_codes = codes_i + b * code_stride;
    for (size_t j = 0; j < M; j++) {
        const size_t code = static_cast<size_t>(entry_codes[j]);
        rows[j] = &codebook_cross_norms[(codebook_offsets[j] + code) * ldc];
    }
}

/// Scalar sum rows[0][k] + rows[1][k] + ... + rows[M - 1][k], accumulated
/// strictly in row order.
template <size_t M>
inline float sum_rows_at(const float* const* rows, const size_t k) {
    float acc = rows[0][k];
    for (size_t j = 1; j < M; j++) {
        acc += rows[j][k];
    }
    return acc;
}

#if defined(__AVX2__) || defined(__aarch64__)
/// SIMD counterpart of sum_rows_at. For each ik < NK, acc[ik] receives the
/// row-order sum of the 8 floats starting at rows[j] + k + 8 * ik.
template <size_t M, size_t NK>
inline void sum_rows_simd(
        const float* const* rows,
        const size_t k,
        simd8float32 (&acc)[NK]) {
    for (size_t ik = 0; ik < NK; ik++) {
        acc[ik].loadu(rows[0] + k + ik * 8);
    }
    for (size_t j = 1; j < M; j++) {
        for (size_t ik = 0; ik < NK; ik++) {
            acc[ik] += simd8float32(rows[j] + k + ik * 8);
        }
    }
}
#endif

/// output[k] = sum over the M rows selected by beam entry b, for k < K.
/// Codes are read with stride m_offset per beam entry.
template <size_t M, size_t NK>
void accum_and_store_tab(
        const size_t m_offset,
        const float* const __restrict codebook_cross_norms,
        const uint64_t* const __restrict codebook_offsets,
        const int32_t* const __restrict codes_i,
        const size_t b,
        const size_t ldc,
        const size_t K,
        float* const __restrict output) {
    const float* rows[M];
    gather_cross_norm_rows<M>(
            m_offset,
            codebook_cross_norms,
            codebook_offsets,
            codes_i,
            b,
            ldc,
            rows);

    size_t k = 0;
#if defined(__AVX2__) || defined(__aarch64__)
    // Explicitly vectorized blocks of 8 * NK floats. The compiler might
    // manage this on its own; this code does not rely on it.
    constexpr size_t block = 8 * NK;
    const size_t simd_end = (K / block) * block;
    for (; k < simd_end; k += block) {
        simd8float32 acc[NK];
        sum_rows_simd<M, NK>(rows, k, acc);
        for (size_t ik = 0; ik < NK; ik++) {
            acc[ik].storeu(output + k + ik * 8);
        }
    }
#endif
    // Tail elements, or the whole range when no SIMD path is compiled.
    for (; k < K; k++) {
        output[k] = sum_rows_at<M>(rows, k);
    }
}

/// output[k] += sum over the M rows selected by beam entry b, for k < K.
/// Codes are read with stride m_offset per beam entry.
template <size_t M, size_t NK>
void accum_and_add_tab(
        const size_t m_offset,
        const float* const __restrict codebook_cross_norms,
        const uint64_t* const __restrict codebook_offsets,
        const int32_t* const __restrict codes_i,
        const size_t b,
        const size_t ldc,
        const size_t K,
        float* const __restrict output) {
    const float* rows[M];
    gather_cross_norm_rows<M>(
            m_offset,
            codebook_cross_norms,
            codebook_offsets,
            codes_i,
            b,
            ldc,
            rows);

    size_t k = 0;
#if defined(__AVX2__) || defined(__aarch64__)
    constexpr size_t block = 8 * NK;
    const size_t simd_end = (K / block) * block;
    for (; k < simd_end; k += block) {
        simd8float32 acc[NK];
        sum_rows_simd<M, NK>(rows, k, acc);
        for (size_t ik = 0; ik < NK; ik++) {
            float* const dst = output + k + ik * 8;
            simd8float32 existing(dst);
            existing += acc[ik];
            existing.storeu(dst);
        }
    }
#endif
    for (; k < K; k++) {
        output[k] += sum_rows_at<M>(rows, k);
    }
}

/// output[b * K + k] = distances_i[b] + cd_common[k] + 2 * S[k], for k < K.
/// S is the sum of the M rows selected by beam entry b, and codes are read
/// with stride M (every code of the entry is summed).
template <size_t M, size_t NK>
void accum_and_finalize_tab(
        const float* const __restrict codebook_cross_norms,
        const uint64_t* const __restrict codebook_offsets,
        const int32_t* const __restrict codes_i,
        const size_t b,
        const size_t ldc,
        const size_t K,
        const float* const __restrict distances_i,
        const float* const __restrict cd_common,
        float* const __restrict output) {
    const float* rows[M];
    gather_cross_norm_rows<M>(
            M, codebook_cross_norms, codebook_offsets, codes_i, b, ldc, rows);

    float* const out_b = output + b * K;
    size_t k = 0;
#if defined(__AVX2__) || defined(__aarch64__)
    constexpr size_t block = 8 * NK;
    const size_t simd_end = (K / block) * block;
    simd8float32 two(2.0f);
    simd8float32 dist_b(distances_i[b]);
    for (; k < simd_end; k += block) {
        simd8float32 acc[NK];
        sum_rows_simd<M, NK>(rows, k, acc);
        for (size_t ik = 0; ik < NK; ik++) {
            // cd_common + 2 * acc (fused multiply-add), then + distances_i[b]
            simd8float32 v(cd_common + k + ik * 8);
            v = fmadd(two, acc[ik], v);
            v += dist_b;
            v.storeu(out_b + k + ik * 8);
        }
    }
#endif
    for (; k < K; k++) {
        out_b[k] = distances_i[b] + cd_common[k] + 2 * sum_rows_at<M>(rows, k);
    }
}

} // anonymous namespace
|
|
224
|
+
|
|
225
|
+
/********************************************************************
|
|
226
|
+
* Single encoding step
|
|
227
|
+
********************************************************************/
|
|
228
|
+
|
|
229
|
+
void beam_search_encode_step(
|
|
230
|
+
size_t d,
|
|
231
|
+
size_t K,
|
|
232
|
+
const float* cent, /// size (K, d)
|
|
233
|
+
size_t n,
|
|
234
|
+
size_t beam_size,
|
|
235
|
+
const float* residuals, /// size (n, beam_size, d)
|
|
236
|
+
size_t m,
|
|
237
|
+
const int32_t* codes, /// size (n, beam_size, m)
|
|
238
|
+
size_t new_beam_size,
|
|
239
|
+
int32_t* new_codes, /// size (n, new_beam_size, m + 1)
|
|
240
|
+
float* new_residuals, /// size (n, new_beam_size, d)
|
|
241
|
+
float* new_distances, /// size (n, new_beam_size)
|
|
242
|
+
Index* assign_index,
|
|
243
|
+
ApproxTopK_mode_t approx_topk_mode) {
|
|
244
|
+
// we have to fill in the whole output matrix
|
|
245
|
+
FAISS_THROW_IF_NOT(new_beam_size <= beam_size * K);
|
|
246
|
+
|
|
247
|
+
std::vector<float> cent_distances;
|
|
248
|
+
std::vector<idx_t> cent_ids;
|
|
249
|
+
|
|
250
|
+
if (assign_index) {
|
|
251
|
+
// search beam_size distances per query
|
|
252
|
+
FAISS_THROW_IF_NOT(assign_index->d == d);
|
|
253
|
+
cent_distances.resize(n * beam_size * new_beam_size);
|
|
254
|
+
cent_ids.resize(n * beam_size * new_beam_size);
|
|
255
|
+
if (assign_index->ntotal != 0) {
|
|
256
|
+
// then we assume the codebooks are already added to the index
|
|
257
|
+
FAISS_THROW_IF_NOT(assign_index->ntotal == K);
|
|
258
|
+
} else {
|
|
259
|
+
assign_index->add(K, cent);
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
// printf("beam_search_encode_step -- mem usage %zd\n",
|
|
263
|
+
// get_mem_usage_kb());
|
|
264
|
+
assign_index->search(
|
|
265
|
+
n * beam_size,
|
|
266
|
+
residuals,
|
|
267
|
+
new_beam_size,
|
|
268
|
+
cent_distances.data(),
|
|
269
|
+
cent_ids.data());
|
|
270
|
+
} else {
|
|
271
|
+
// do one big distance computation
|
|
272
|
+
cent_distances.resize(n * beam_size * K);
|
|
273
|
+
pairwise_L2sqr(
|
|
274
|
+
d, n * beam_size, residuals, K, cent, cent_distances.data());
|
|
275
|
+
}
|
|
276
|
+
InterruptCallback::check();
|
|
277
|
+
|
|
278
|
+
#pragma omp parallel for if (n > 100)
|
|
279
|
+
for (int64_t i = 0; i < n; i++) {
|
|
280
|
+
const int32_t* codes_i = codes + i * m * beam_size;
|
|
281
|
+
int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
|
|
282
|
+
const float* residuals_i = residuals + i * d * beam_size;
|
|
283
|
+
float* new_residuals_i = new_residuals + i * d * new_beam_size;
|
|
284
|
+
|
|
285
|
+
float* new_distances_i = new_distances + i * new_beam_size;
|
|
286
|
+
using C = CMax<float, int>;
|
|
287
|
+
|
|
288
|
+
if (assign_index) {
|
|
289
|
+
const float* cent_distances_i =
|
|
290
|
+
cent_distances.data() + i * beam_size * new_beam_size;
|
|
291
|
+
const idx_t* cent_ids_i =
|
|
292
|
+
cent_ids.data() + i * beam_size * new_beam_size;
|
|
293
|
+
|
|
294
|
+
// here we could be a tad more efficient by merging sorted arrays
|
|
295
|
+
for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
|
|
296
|
+
new_distances_i[i_2] = C::neutral();
|
|
297
|
+
}
|
|
298
|
+
std::vector<int> perm(new_beam_size, -1);
|
|
299
|
+
heap_addn<C>(
|
|
300
|
+
new_beam_size,
|
|
301
|
+
new_distances_i,
|
|
302
|
+
perm.data(),
|
|
303
|
+
cent_distances_i,
|
|
304
|
+
nullptr,
|
|
305
|
+
beam_size * new_beam_size);
|
|
306
|
+
heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
|
|
307
|
+
|
|
308
|
+
for (int j = 0; j < new_beam_size; j++) {
|
|
309
|
+
int js = perm[j] / new_beam_size;
|
|
310
|
+
int ls = cent_ids_i[perm[j]];
|
|
311
|
+
if (m > 0) {
|
|
312
|
+
memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
|
|
313
|
+
}
|
|
314
|
+
new_codes_i[m] = ls;
|
|
315
|
+
new_codes_i += m + 1;
|
|
316
|
+
fvec_sub(
|
|
317
|
+
d,
|
|
318
|
+
residuals_i + js * d,
|
|
319
|
+
cent + ls * d,
|
|
320
|
+
new_residuals_i);
|
|
321
|
+
new_residuals_i += d;
|
|
322
|
+
}
|
|
323
|
+
|
|
324
|
+
} else {
|
|
325
|
+
const float* cent_distances_i =
|
|
326
|
+
cent_distances.data() + i * beam_size * K;
|
|
327
|
+
// then we have to select the best results
|
|
328
|
+
for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
|
|
329
|
+
new_distances_i[i_2] = C::neutral();
|
|
330
|
+
}
|
|
331
|
+
std::vector<int> perm(new_beam_size, -1);
|
|
332
|
+
|
|
333
|
+
#define HANDLE_APPROX(NB, BD) \
|
|
334
|
+
case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
|
|
335
|
+
HeapWithBuckets<C, NB, BD>::bs_addn( \
|
|
336
|
+
beam_size, \
|
|
337
|
+
K, \
|
|
338
|
+
cent_distances_i, \
|
|
339
|
+
new_beam_size, \
|
|
340
|
+
new_distances_i, \
|
|
341
|
+
perm.data()); \
|
|
342
|
+
break;
|
|
343
|
+
|
|
344
|
+
switch (approx_topk_mode) {
|
|
345
|
+
HANDLE_APPROX(8, 3)
|
|
346
|
+
HANDLE_APPROX(8, 2)
|
|
347
|
+
HANDLE_APPROX(16, 2)
|
|
348
|
+
HANDLE_APPROX(32, 2)
|
|
349
|
+
default:
|
|
350
|
+
heap_addn<C>(
|
|
351
|
+
new_beam_size,
|
|
352
|
+
new_distances_i,
|
|
353
|
+
perm.data(),
|
|
354
|
+
cent_distances_i,
|
|
355
|
+
nullptr,
|
|
356
|
+
beam_size * K);
|
|
357
|
+
}
|
|
358
|
+
heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
|
|
359
|
+
|
|
360
|
+
#undef HANDLE_APPROX
|
|
361
|
+
|
|
362
|
+
for (int j = 0; j < new_beam_size; j++) {
|
|
363
|
+
int js = perm[j] / K;
|
|
364
|
+
int ls = perm[j] % K;
|
|
365
|
+
if (m > 0) {
|
|
366
|
+
memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
|
|
367
|
+
}
|
|
368
|
+
new_codes_i[m] = ls;
|
|
369
|
+
new_codes_i += m + 1;
|
|
370
|
+
fvec_sub(
|
|
371
|
+
d,
|
|
372
|
+
residuals_i + js * d,
|
|
373
|
+
cent + ls * d,
|
|
374
|
+
new_residuals_i);
|
|
375
|
+
new_residuals_i += d;
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
}
|
|
379
|
+
}
|
|
380
|
+
|
|
381
|
+
// exposed in the faiss namespace
|
|
382
|
+
void beam_search_encode_step_tab(
|
|
383
|
+
size_t K,
|
|
384
|
+
size_t n,
|
|
385
|
+
size_t beam_size, // input sizes
|
|
386
|
+
const float* codebook_cross_norms, // size K * ldc
|
|
387
|
+
size_t ldc,
|
|
388
|
+
const uint64_t* codebook_offsets, // m
|
|
389
|
+
const float* query_cp, // size n * ldqc
|
|
390
|
+
size_t ldqc, // >= K
|
|
391
|
+
const float* cent_norms_i, // size K
|
|
392
|
+
size_t m,
|
|
393
|
+
const int32_t* codes, // n * beam_size * m
|
|
394
|
+
const float* distances, // n * beam_size
|
|
395
|
+
size_t new_beam_size,
|
|
396
|
+
int32_t* new_codes, // n * new_beam_size * (m + 1)
|
|
397
|
+
float* new_distances, // n * new_beam_size
|
|
398
|
+
ApproxTopK_mode_t approx_topk_mode) //
|
|
399
|
+
{
|
|
400
|
+
FAISS_THROW_IF_NOT(ldc >= K);
|
|
401
|
+
|
|
402
|
+
#pragma omp parallel for if (n > 100) schedule(dynamic)
|
|
403
|
+
for (int64_t i = 0; i < n; i++) {
|
|
404
|
+
std::vector<float> cent_distances(beam_size * K);
|
|
405
|
+
std::vector<float> cd_common(K);
|
|
406
|
+
|
|
407
|
+
const int32_t* codes_i = codes + i * m * beam_size;
|
|
408
|
+
const float* query_cp_i = query_cp + i * ldqc;
|
|
409
|
+
const float* distances_i = distances + i * beam_size;
|
|
410
|
+
|
|
411
|
+
for (size_t k = 0; k < K; k++) {
|
|
412
|
+
cd_common[k] = cent_norms_i[k] - 2 * query_cp_i[k];
|
|
413
|
+
}
|
|
414
|
+
|
|
415
|
+
bool use_baseline_implementation = false;
|
|
416
|
+
|
|
417
|
+
// This is the baseline implementation. Its primary flaw
|
|
418
|
+
// that it writes way too many info to the temporary buffer
|
|
419
|
+
// called dp.
|
|
420
|
+
//
|
|
421
|
+
// This baseline code is kept intentionally because it is easy to
|
|
422
|
+
// understand what an optimized version optimizes exactly.
|
|
423
|
+
//
|
|
424
|
+
if (use_baseline_implementation) {
|
|
425
|
+
for (size_t b = 0; b < beam_size; b++) {
|
|
426
|
+
std::vector<float> dp(K);
|
|
427
|
+
|
|
428
|
+
for (size_t m1 = 0; m1 < m; m1++) {
|
|
429
|
+
size_t c = codes_i[b * m + m1];
|
|
430
|
+
const float* cb =
|
|
431
|
+
&codebook_cross_norms
|
|
432
|
+
[(codebook_offsets[m1] + c) * ldc];
|
|
433
|
+
fvec_add(K, cb, dp.data(), dp.data());
|
|
434
|
+
}
|
|
435
|
+
|
|
436
|
+
for (size_t k = 0; k < K; k++) {
|
|
437
|
+
cent_distances[b * K + k] =
|
|
438
|
+
distances_i[b] + cd_common[k] + 2 * dp[k];
|
|
439
|
+
}
|
|
440
|
+
}
|
|
441
|
+
|
|
442
|
+
} else {
|
|
443
|
+
// An optimized implementation that avoids using a temporary buffer
|
|
444
|
+
// and does the accumulation in registers.
|
|
445
|
+
|
|
446
|
+
// Compute a sum of NK AQ codes.
|
|
447
|
+
#define ACCUM_AND_FINALIZE_TAB(NK) \
|
|
448
|
+
case NK: \
|
|
449
|
+
for (size_t b = 0; b < beam_size; b++) { \
|
|
450
|
+
accum_and_finalize_tab<NK, 4>( \
|
|
451
|
+
codebook_cross_norms, \
|
|
452
|
+
codebook_offsets, \
|
|
453
|
+
codes_i, \
|
|
454
|
+
b, \
|
|
455
|
+
ldc, \
|
|
456
|
+
K, \
|
|
457
|
+
distances_i, \
|
|
458
|
+
cd_common.data(), \
|
|
459
|
+
cent_distances.data()); \
|
|
460
|
+
} \
|
|
461
|
+
break;
|
|
462
|
+
|
|
463
|
+
// this version contains many switch-case scenarios, but
|
|
464
|
+
// they won't affect branch predictor.
|
|
465
|
+
switch (m) {
|
|
466
|
+
case 0:
|
|
467
|
+
// trivial case
|
|
468
|
+
for (size_t b = 0; b < beam_size; b++) {
|
|
469
|
+
for (size_t k = 0; k < K; k++) {
|
|
470
|
+
cent_distances[b * K + k] =
|
|
471
|
+
distances_i[b] + cd_common[k];
|
|
472
|
+
}
|
|
473
|
+
}
|
|
474
|
+
break;
|
|
475
|
+
|
|
476
|
+
ACCUM_AND_FINALIZE_TAB(1)
|
|
477
|
+
ACCUM_AND_FINALIZE_TAB(2)
|
|
478
|
+
ACCUM_AND_FINALIZE_TAB(3)
|
|
479
|
+
ACCUM_AND_FINALIZE_TAB(4)
|
|
480
|
+
ACCUM_AND_FINALIZE_TAB(5)
|
|
481
|
+
ACCUM_AND_FINALIZE_TAB(6)
|
|
482
|
+
ACCUM_AND_FINALIZE_TAB(7)
|
|
483
|
+
|
|
484
|
+
default: {
|
|
485
|
+
// m >= 8 case.
|
|
486
|
+
|
|
487
|
+
// A temporary buffer has to be used due to the lack of
|
|
488
|
+
// registers. But we'll try to accumulate up to 8 AQ codes
|
|
489
|
+
// in registers and issue a single write operation to the
|
|
490
|
+
// buffer, while the baseline does no accumulation. So, the
|
|
491
|
+
// number of write operations to the temporary buffer is
|
|
492
|
+
// reduced 8x.
|
|
493
|
+
|
|
494
|
+
// allocate a temporary buffer
|
|
495
|
+
std::vector<float> dp(K);
|
|
496
|
+
|
|
497
|
+
for (size_t b = 0; b < beam_size; b++) {
|
|
498
|
+
// Initialize it. Compute a sum of first 8 AQ codes
|
|
499
|
+
// because m >= 8 .
|
|
500
|
+
accum_and_store_tab<8, 4>(
|
|
501
|
+
m,
|
|
502
|
+
codebook_cross_norms,
|
|
503
|
+
codebook_offsets,
|
|
504
|
+
codes_i,
|
|
505
|
+
b,
|
|
506
|
+
ldc,
|
|
507
|
+
K,
|
|
508
|
+
dp.data());
|
|
509
|
+
|
|
510
|
+
#define ACCUM_AND_ADD_TAB(NK) \
|
|
511
|
+
case NK: \
|
|
512
|
+
accum_and_add_tab<NK, 4>( \
|
|
513
|
+
m, \
|
|
514
|
+
codebook_cross_norms, \
|
|
515
|
+
codebook_offsets + im, \
|
|
516
|
+
codes_i + im, \
|
|
517
|
+
b, \
|
|
518
|
+
ldc, \
|
|
519
|
+
K, \
|
|
520
|
+
dp.data()); \
|
|
521
|
+
break;
|
|
522
|
+
|
|
523
|
+
// accumulate up to 8 additional AQ codes into
|
|
524
|
+
// a temporary buffer
|
|
525
|
+
for (size_t im = 8; im < ((m + 7) / 8) * 8; im += 8) {
|
|
526
|
+
size_t m_left = m - im;
|
|
527
|
+
if (m_left > 8) {
|
|
528
|
+
m_left = 8;
|
|
529
|
+
}
|
|
530
|
+
|
|
531
|
+
switch (m_left) {
|
|
532
|
+
ACCUM_AND_ADD_TAB(1)
|
|
533
|
+
ACCUM_AND_ADD_TAB(2)
|
|
534
|
+
ACCUM_AND_ADD_TAB(3)
|
|
535
|
+
ACCUM_AND_ADD_TAB(4)
|
|
536
|
+
ACCUM_AND_ADD_TAB(5)
|
|
537
|
+
ACCUM_AND_ADD_TAB(6)
|
|
538
|
+
ACCUM_AND_ADD_TAB(7)
|
|
539
|
+
ACCUM_AND_ADD_TAB(8)
|
|
540
|
+
}
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
// done. finalize the result
|
|
544
|
+
for (size_t k = 0; k < K; k++) {
|
|
545
|
+
cent_distances[b * K + k] =
|
|
546
|
+
distances_i[b] + cd_common[k] + 2 * dp[k];
|
|
547
|
+
}
|
|
548
|
+
}
|
|
549
|
+
}
|
|
550
|
+
}
|
|
551
|
+
|
|
552
|
+
// the optimized implementation ends here
|
|
553
|
+
}
|
|
554
|
+
using C = CMax<float, int>;
|
|
555
|
+
int32_t* new_codes_i = new_codes + i * (m + 1) * new_beam_size;
|
|
556
|
+
float* new_distances_i = new_distances + i * new_beam_size;
|
|
557
|
+
|
|
558
|
+
const float* cent_distances_i = cent_distances.data();
|
|
559
|
+
|
|
560
|
+
// then we have to select the best results
|
|
561
|
+
for (int i_2 = 0; i_2 < new_beam_size; i_2++) {
|
|
562
|
+
new_distances_i[i_2] = C::neutral();
|
|
563
|
+
}
|
|
564
|
+
std::vector<int> perm(new_beam_size, -1);
|
|
565
|
+
|
|
566
|
+
#define HANDLE_APPROX(NB, BD) \
|
|
567
|
+
case ApproxTopK_mode_t::APPROX_TOPK_BUCKETS_B##NB##_D##BD: \
|
|
568
|
+
HeapWithBuckets<C, NB, BD>::bs_addn( \
|
|
569
|
+
beam_size, \
|
|
570
|
+
K, \
|
|
571
|
+
cent_distances_i, \
|
|
572
|
+
new_beam_size, \
|
|
573
|
+
new_distances_i, \
|
|
574
|
+
perm.data()); \
|
|
575
|
+
break;
|
|
576
|
+
|
|
577
|
+
switch (approx_topk_mode) {
|
|
578
|
+
HANDLE_APPROX(8, 3)
|
|
579
|
+
HANDLE_APPROX(8, 2)
|
|
580
|
+
HANDLE_APPROX(16, 2)
|
|
581
|
+
HANDLE_APPROX(32, 2)
|
|
582
|
+
default:
|
|
583
|
+
heap_addn<C>(
|
|
584
|
+
new_beam_size,
|
|
585
|
+
new_distances_i,
|
|
586
|
+
perm.data(),
|
|
587
|
+
cent_distances_i,
|
|
588
|
+
nullptr,
|
|
589
|
+
beam_size * K);
|
|
590
|
+
break;
|
|
591
|
+
}
|
|
592
|
+
|
|
593
|
+
heap_reorder<C>(new_beam_size, new_distances_i, perm.data());
|
|
594
|
+
|
|
595
|
+
#undef HANDLE_APPROX
|
|
596
|
+
|
|
597
|
+
for (int j = 0; j < new_beam_size; j++) {
|
|
598
|
+
int js = perm[j] / K;
|
|
599
|
+
int ls = perm[j] % K;
|
|
600
|
+
if (m > 0) {
|
|
601
|
+
memcpy(new_codes_i, codes_i + js * m, sizeof(*codes) * m);
|
|
602
|
+
}
|
|
603
|
+
new_codes_i[m] = ls;
|
|
604
|
+
new_codes_i += m + 1;
|
|
605
|
+
}
|
|
606
|
+
}
|
|
607
|
+
}
|
|
608
|
+
|
|
609
|
+
/********************************************************************
|
|
610
|
+
* Multiple encoding steps
|
|
611
|
+
********************************************************************/
|
|
612
|
+
|
|
613
|
+
namespace rq_encode_steps {
|
|
614
|
+
|
|
615
|
+
void refine_beam_mp(
|
|
616
|
+
const ResidualQuantizer& rq,
|
|
617
|
+
size_t n,
|
|
618
|
+
size_t beam_size,
|
|
619
|
+
const float* x,
|
|
620
|
+
int out_beam_size,
|
|
621
|
+
int32_t* out_codes,
|
|
622
|
+
float* out_residuals,
|
|
623
|
+
float* out_distances,
|
|
624
|
+
RefineBeamMemoryPool& pool) {
|
|
625
|
+
int cur_beam_size = beam_size;
|
|
626
|
+
|
|
627
|
+
double t0 = getmillisecs();
|
|
628
|
+
|
|
629
|
+
// find the max_beam_size
|
|
630
|
+
int max_beam_size = 0;
|
|
631
|
+
{
|
|
632
|
+
int tmp_beam_size = cur_beam_size;
|
|
633
|
+
for (int m = 0; m < rq.M; m++) {
|
|
634
|
+
int K = 1 << rq.nbits[m];
|
|
635
|
+
int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
|
|
636
|
+
tmp_beam_size = new_beam_size;
|
|
637
|
+
|
|
638
|
+
if (max_beam_size < new_beam_size) {
|
|
639
|
+
max_beam_size = new_beam_size;
|
|
640
|
+
}
|
|
641
|
+
}
|
|
642
|
+
}
|
|
643
|
+
|
|
644
|
+
// preallocate buffers
|
|
645
|
+
pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
|
|
646
|
+
pool.new_residuals.resize(n * max_beam_size * rq.d);
|
|
647
|
+
|
|
648
|
+
pool.codes.resize(n * max_beam_size * (rq.M + 1));
|
|
649
|
+
pool.distances.resize(n * max_beam_size);
|
|
650
|
+
pool.residuals.resize(n * rq.d * max_beam_size);
|
|
651
|
+
|
|
652
|
+
for (size_t i = 0; i < n * rq.d * beam_size; i++) {
|
|
653
|
+
pool.residuals[i] = x[i];
|
|
654
|
+
}
|
|
655
|
+
|
|
656
|
+
// set up pointers to buffers
|
|
657
|
+
int32_t* __restrict codes_ptr = pool.codes.data();
|
|
658
|
+
float* __restrict residuals_ptr = pool.residuals.data();
|
|
659
|
+
|
|
660
|
+
int32_t* __restrict new_codes_ptr = pool.new_codes.data();
|
|
661
|
+
float* __restrict new_residuals_ptr = pool.new_residuals.data();
|
|
662
|
+
|
|
663
|
+
// index
|
|
664
|
+
std::unique_ptr<Index> assign_index;
|
|
665
|
+
if (rq.assign_index_factory) {
|
|
666
|
+
assign_index.reset((*rq.assign_index_factory)(rq.d));
|
|
667
|
+
}
|
|
668
|
+
|
|
669
|
+
// main loop
|
|
670
|
+
size_t codes_size = 0;
|
|
671
|
+
size_t distances_size = 0;
|
|
672
|
+
size_t residuals_size = 0;
|
|
673
|
+
|
|
674
|
+
for (int m = 0; m < rq.M; m++) {
|
|
675
|
+
int K = 1 << rq.nbits[m];
|
|
676
|
+
|
|
677
|
+
const float* __restrict codebooks_m =
|
|
678
|
+
rq.codebooks.data() + rq.codebook_offsets[m] * rq.d;
|
|
679
|
+
|
|
680
|
+
const int new_beam_size = std::min(cur_beam_size * K, out_beam_size);
|
|
681
|
+
|
|
682
|
+
codes_size = n * new_beam_size * (m + 1);
|
|
683
|
+
residuals_size = n * new_beam_size * rq.d;
|
|
684
|
+
distances_size = n * new_beam_size;
|
|
685
|
+
|
|
686
|
+
beam_search_encode_step(
|
|
687
|
+
rq.d,
|
|
688
|
+
K,
|
|
689
|
+
codebooks_m,
|
|
690
|
+
n,
|
|
691
|
+
cur_beam_size,
|
|
692
|
+
residuals_ptr,
|
|
693
|
+
m,
|
|
694
|
+
codes_ptr,
|
|
695
|
+
new_beam_size,
|
|
696
|
+
new_codes_ptr,
|
|
697
|
+
new_residuals_ptr,
|
|
698
|
+
pool.distances.data(),
|
|
699
|
+
assign_index.get(),
|
|
700
|
+
rq.approx_topk_mode);
|
|
701
|
+
|
|
702
|
+
if (assign_index != nullptr) {
|
|
703
|
+
assign_index->reset();
|
|
704
|
+
}
|
|
705
|
+
|
|
706
|
+
std::swap(codes_ptr, new_codes_ptr);
|
|
707
|
+
std::swap(residuals_ptr, new_residuals_ptr);
|
|
708
|
+
|
|
709
|
+
cur_beam_size = new_beam_size;
|
|
710
|
+
|
|
711
|
+
if (rq.verbose) {
|
|
712
|
+
float sum_distances = 0;
|
|
713
|
+
for (int j = 0; j < distances_size; j++) {
|
|
714
|
+
sum_distances += pool.distances[j];
|
|
715
|
+
}
|
|
716
|
+
|
|
717
|
+
printf("[%.3f s] encode stage %d, %d bits, "
|
|
718
|
+
"total error %g, beam_size %d\n",
|
|
719
|
+
(getmillisecs() - t0) / 1000,
|
|
720
|
+
m,
|
|
721
|
+
int(rq.nbits[m]),
|
|
722
|
+
sum_distances,
|
|
723
|
+
cur_beam_size);
|
|
724
|
+
}
|
|
725
|
+
}
|
|
726
|
+
|
|
727
|
+
if (out_codes) {
|
|
728
|
+
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
|
|
729
|
+
}
|
|
730
|
+
if (out_residuals) {
|
|
731
|
+
memcpy(out_residuals,
|
|
732
|
+
residuals_ptr,
|
|
733
|
+
residuals_size * sizeof(*residuals_ptr));
|
|
734
|
+
}
|
|
735
|
+
if (out_distances) {
|
|
736
|
+
memcpy(out_distances,
|
|
737
|
+
pool.distances.data(),
|
|
738
|
+
distances_size * sizeof(pool.distances[0]));
|
|
739
|
+
}
|
|
740
|
+
}
|
|
741
|
+
|
|
742
|
+
void refine_beam_LUT_mp(
|
|
743
|
+
const ResidualQuantizer& rq,
|
|
744
|
+
size_t n,
|
|
745
|
+
const float* query_norms, // size n
|
|
746
|
+
const float* query_cp, //
|
|
747
|
+
int out_beam_size,
|
|
748
|
+
int32_t* out_codes,
|
|
749
|
+
float* out_distances,
|
|
750
|
+
RefineBeamLUTMemoryPool& pool) {
|
|
751
|
+
int beam_size = 1;
|
|
752
|
+
|
|
753
|
+
double t0 = getmillisecs();
|
|
754
|
+
|
|
755
|
+
// find the max_beam_size
|
|
756
|
+
int max_beam_size = 0;
|
|
757
|
+
{
|
|
758
|
+
int tmp_beam_size = beam_size;
|
|
759
|
+
for (int m = 0; m < rq.M; m++) {
|
|
760
|
+
int K = 1 << rq.nbits[m];
|
|
761
|
+
int new_beam_size = std::min(tmp_beam_size * K, out_beam_size);
|
|
762
|
+
tmp_beam_size = new_beam_size;
|
|
763
|
+
|
|
764
|
+
if (max_beam_size < new_beam_size) {
|
|
765
|
+
max_beam_size = new_beam_size;
|
|
766
|
+
}
|
|
767
|
+
}
|
|
768
|
+
}
|
|
769
|
+
|
|
770
|
+
// preallocate buffers
|
|
771
|
+
pool.new_codes.resize(n * max_beam_size * (rq.M + 1));
|
|
772
|
+
pool.new_distances.resize(n * max_beam_size);
|
|
773
|
+
|
|
774
|
+
pool.codes.resize(n * max_beam_size * (rq.M + 1));
|
|
775
|
+
pool.distances.resize(n * max_beam_size);
|
|
776
|
+
|
|
777
|
+
for (size_t i = 0; i < n; i++) {
|
|
778
|
+
pool.distances[i] = query_norms[i];
|
|
779
|
+
}
|
|
780
|
+
|
|
781
|
+
// set up pointers to buffers
|
|
782
|
+
int32_t* __restrict new_codes_ptr = pool.new_codes.data();
|
|
783
|
+
float* __restrict new_distances_ptr = pool.new_distances.data();
|
|
784
|
+
|
|
785
|
+
int32_t* __restrict codes_ptr = pool.codes.data();
|
|
786
|
+
float* __restrict distances_ptr = pool.distances.data();
|
|
787
|
+
|
|
788
|
+
// main loop
|
|
789
|
+
size_t codes_size = 0;
|
|
790
|
+
size_t distances_size = 0;
|
|
791
|
+
size_t cross_ofs = 0;
|
|
792
|
+
for (int m = 0; m < rq.M; m++) {
|
|
793
|
+
int K = 1 << rq.nbits[m];
|
|
794
|
+
|
|
795
|
+
// it is guaranteed that (new_beam_size <= max_beam_size)
|
|
796
|
+
int new_beam_size = std::min(beam_size * K, out_beam_size);
|
|
797
|
+
|
|
798
|
+
codes_size = n * new_beam_size * (m + 1);
|
|
799
|
+
distances_size = n * new_beam_size;
|
|
800
|
+
FAISS_THROW_IF_NOT(
|
|
801
|
+
cross_ofs + rq.codebook_offsets[m] * K <=
|
|
802
|
+
rq.codebook_cross_products.size());
|
|
803
|
+
beam_search_encode_step_tab(
|
|
804
|
+
K,
|
|
805
|
+
n,
|
|
806
|
+
beam_size,
|
|
807
|
+
rq.codebook_cross_products.data() + cross_ofs,
|
|
808
|
+
K,
|
|
809
|
+
rq.codebook_offsets.data(),
|
|
810
|
+
query_cp + rq.codebook_offsets[m],
|
|
811
|
+
rq.total_codebook_size,
|
|
812
|
+
rq.cent_norms.data() + rq.codebook_offsets[m],
|
|
813
|
+
m,
|
|
814
|
+
codes_ptr,
|
|
815
|
+
distances_ptr,
|
|
816
|
+
new_beam_size,
|
|
817
|
+
new_codes_ptr,
|
|
818
|
+
new_distances_ptr,
|
|
819
|
+
rq.approx_topk_mode);
|
|
820
|
+
cross_ofs += rq.codebook_offsets[m] * K;
|
|
821
|
+
std::swap(codes_ptr, new_codes_ptr);
|
|
822
|
+
std::swap(distances_ptr, new_distances_ptr);
|
|
823
|
+
|
|
824
|
+
beam_size = new_beam_size;
|
|
825
|
+
|
|
826
|
+
if (rq.verbose) {
|
|
827
|
+
float sum_distances = 0;
|
|
828
|
+
for (int j = 0; j < distances_size; j++) {
|
|
829
|
+
sum_distances += distances_ptr[j];
|
|
830
|
+
}
|
|
831
|
+
printf("[%.3f s] encode stage %d, %d bits, "
|
|
832
|
+
"total error %g, beam_size %d\n",
|
|
833
|
+
(getmillisecs() - t0) / 1000,
|
|
834
|
+
m,
|
|
835
|
+
int(rq.nbits[m]),
|
|
836
|
+
sum_distances,
|
|
837
|
+
beam_size);
|
|
838
|
+
}
|
|
839
|
+
}
|
|
840
|
+
if (out_codes) {
|
|
841
|
+
memcpy(out_codes, codes_ptr, codes_size * sizeof(*codes_ptr));
|
|
842
|
+
}
|
|
843
|
+
if (out_distances) {
|
|
844
|
+
memcpy(out_distances,
|
|
845
|
+
distances_ptr,
|
|
846
|
+
distances_size * sizeof(*distances_ptr));
|
|
847
|
+
}
|
|
848
|
+
}
|
|
849
|
+
|
|
850
|
+
// this is for use_beam_LUT == 0
|
|
851
|
+
void compute_codes_add_centroids_mp_lut0(
|
|
852
|
+
const ResidualQuantizer& rq,
|
|
853
|
+
const float* x,
|
|
854
|
+
uint8_t* codes_out,
|
|
855
|
+
size_t n,
|
|
856
|
+
const float* centroids,
|
|
857
|
+
ComputeCodesAddCentroidsLUT0MemoryPool& pool) {
|
|
858
|
+
pool.codes.resize(rq.max_beam_size * rq.M * n);
|
|
859
|
+
pool.distances.resize(rq.max_beam_size * n);
|
|
860
|
+
|
|
861
|
+
pool.residuals.resize(rq.max_beam_size * n * rq.d);
|
|
862
|
+
|
|
863
|
+
refine_beam_mp(
|
|
864
|
+
rq,
|
|
865
|
+
n,
|
|
866
|
+
1,
|
|
867
|
+
x,
|
|
868
|
+
rq.max_beam_size,
|
|
869
|
+
pool.codes.data(),
|
|
870
|
+
pool.residuals.data(),
|
|
871
|
+
pool.distances.data(),
|
|
872
|
+
pool.refine_beam_pool);
|
|
873
|
+
|
|
874
|
+
if (rq.search_type == ResidualQuantizer::ST_norm_float ||
|
|
875
|
+
rq.search_type == ResidualQuantizer::ST_norm_qint8 ||
|
|
876
|
+
rq.search_type == ResidualQuantizer::ST_norm_qint4) {
|
|
877
|
+
pool.norms.resize(n);
|
|
878
|
+
// recover the norms of reconstruction as
|
|
879
|
+
// || original_vector - residual ||^2
|
|
880
|
+
for (size_t i = 0; i < n; i++) {
|
|
881
|
+
pool.norms[i] = fvec_L2sqr(
|
|
882
|
+
x + i * rq.d,
|
|
883
|
+
pool.residuals.data() + i * rq.max_beam_size * rq.d,
|
|
884
|
+
rq.d);
|
|
885
|
+
}
|
|
886
|
+
}
|
|
887
|
+
|
|
888
|
+
// pack only the first code of the beam
|
|
889
|
+
// (hence the ld_codes=M * max_beam_size)
|
|
890
|
+
rq.pack_codes(
|
|
891
|
+
n,
|
|
892
|
+
pool.codes.data(),
|
|
893
|
+
codes_out,
|
|
894
|
+
rq.M * rq.max_beam_size,
|
|
895
|
+
(pool.norms.size() > 0) ? pool.norms.data() : nullptr,
|
|
896
|
+
centroids);
|
|
897
|
+
}
|
|
898
|
+
|
|
899
|
+
// use_beam_LUT == 1
|
|
900
|
+
void compute_codes_add_centroids_mp_lut1(
|
|
901
|
+
const ResidualQuantizer& rq,
|
|
902
|
+
const float* x,
|
|
903
|
+
uint8_t* codes_out,
|
|
904
|
+
size_t n,
|
|
905
|
+
const float* centroids,
|
|
906
|
+
ComputeCodesAddCentroidsLUT1MemoryPool& pool) {
|
|
907
|
+
//
|
|
908
|
+
pool.codes.resize(rq.max_beam_size * rq.M * n);
|
|
909
|
+
pool.distances.resize(rq.max_beam_size * n);
|
|
910
|
+
|
|
911
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
912
|
+
rq.M == 1 || rq.codebook_cross_products.size() > 0,
|
|
913
|
+
"call compute_codebook_tables first");
|
|
914
|
+
|
|
915
|
+
pool.query_norms.resize(n);
|
|
916
|
+
fvec_norms_L2sqr(pool.query_norms.data(), x, rq.d, n);
|
|
917
|
+
|
|
918
|
+
pool.query_cp.resize(n * rq.total_codebook_size);
|
|
919
|
+
{
|
|
920
|
+
FINTEGER ti = rq.total_codebook_size, di = rq.d, ni = n;
|
|
921
|
+
float zero = 0, one = 1;
|
|
922
|
+
sgemm_("Transposed",
|
|
923
|
+
"Not transposed",
|
|
924
|
+
&ti,
|
|
925
|
+
&ni,
|
|
926
|
+
&di,
|
|
927
|
+
&one,
|
|
928
|
+
rq.codebooks.data(),
|
|
929
|
+
&di,
|
|
930
|
+
x,
|
|
931
|
+
&di,
|
|
932
|
+
&zero,
|
|
933
|
+
pool.query_cp.data(),
|
|
934
|
+
&ti);
|
|
935
|
+
}
|
|
936
|
+
|
|
937
|
+
refine_beam_LUT_mp(
|
|
938
|
+
rq,
|
|
939
|
+
n,
|
|
940
|
+
pool.query_norms.data(),
|
|
941
|
+
pool.query_cp.data(),
|
|
942
|
+
rq.max_beam_size,
|
|
943
|
+
pool.codes.data(),
|
|
944
|
+
pool.distances.data(),
|
|
945
|
+
pool.refine_beam_lut_pool);
|
|
946
|
+
|
|
947
|
+
// pack only the first code of the beam
|
|
948
|
+
// (hence the ld_codes=M * max_beam_size)
|
|
949
|
+
rq.pack_codes(
|
|
950
|
+
n,
|
|
951
|
+
pool.codes.data(),
|
|
952
|
+
codes_out,
|
|
953
|
+
rq.M * rq.max_beam_size,
|
|
954
|
+
nullptr,
|
|
955
|
+
centroids);
|
|
956
|
+
}
|
|
957
|
+
|
|
958
|
+
} // namespace rq_encode_steps
|
|
959
|
+
|
|
960
|
+
} // namespace faiss
|