faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
|
@@ -70,10 +70,11 @@ bool getTensorCoreSupport(int device);
|
|
|
70
70
|
/// Equivalent to getTensorCoreSupport(getCurrentDevice())
|
|
71
71
|
bool getTensorCoreSupportCurrentDevice();
|
|
72
72
|
|
|
73
|
-
/// Returns the
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
|
|
73
|
+
/// Returns the amount of currently available memory on the given device
|
|
74
|
+
size_t getFreeMemory(int device);
|
|
75
|
+
|
|
76
|
+
/// Equivalent to getFreeMemory(getCurrentDevice())
|
|
77
|
+
size_t getFreeMemoryCurrentDevice();
|
|
77
78
|
|
|
78
79
|
/// RAII object to set the current device, and restore the previous
|
|
79
80
|
/// device upon destruction
|
|
@@ -8,7 +8,6 @@
|
|
|
8
8
|
// -*- c++ -*-
|
|
9
9
|
|
|
10
10
|
#include <faiss/impl/AdditiveQuantizer.h>
|
|
11
|
-
#include <faiss/impl/FaissAssert.h>
|
|
12
11
|
|
|
13
12
|
#include <cstddef>
|
|
14
13
|
#include <cstdio>
|
|
@@ -18,9 +17,13 @@
|
|
|
18
17
|
|
|
19
18
|
#include <algorithm>
|
|
20
19
|
|
|
20
|
+
#include <faiss/Clustering.h>
|
|
21
|
+
#include <faiss/impl/FaissAssert.h>
|
|
22
|
+
#include <faiss/impl/LocalSearchQuantizer.h>
|
|
23
|
+
#include <faiss/impl/ResidualQuantizer.h>
|
|
21
24
|
#include <faiss/utils/Heap.h>
|
|
22
25
|
#include <faiss/utils/distances.h>
|
|
23
|
-
#include <faiss/utils/hamming.h>
|
|
26
|
+
#include <faiss/utils/hamming.h>
|
|
24
27
|
#include <faiss/utils/utils.h>
|
|
25
28
|
|
|
26
29
|
extern "C" {
|
|
@@ -42,51 +45,211 @@ int sgemm_(
|
|
|
42
45
|
FINTEGER* ldc);
|
|
43
46
|
}
|
|
44
47
|
|
|
45
|
-
namespace {
|
|
46
|
-
|
|
47
|
-
// c and a and b can overlap
|
|
48
|
-
void fvec_add(size_t d, const float* a, const float* b, float* c) {
|
|
49
|
-
for (size_t i = 0; i < d; i++) {
|
|
50
|
-
c[i] = a[i] + b[i];
|
|
51
|
-
}
|
|
52
|
-
}
|
|
48
|
+
namespace faiss {
|
|
53
49
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
50
|
+
AdditiveQuantizer::AdditiveQuantizer(
|
|
51
|
+
size_t d,
|
|
52
|
+
const std::vector<size_t>& nbits,
|
|
53
|
+
Search_type_t search_type)
|
|
54
|
+
: Quantizer(d),
|
|
55
|
+
M(nbits.size()),
|
|
56
|
+
nbits(nbits),
|
|
57
|
+
verbose(false),
|
|
58
|
+
is_trained(false),
|
|
59
|
+
max_mem_distances(5 * (size_t(1) << 30)), // 5 GiB
|
|
60
|
+
search_type(search_type) {
|
|
61
|
+
norm_max = norm_min = NAN;
|
|
62
|
+
tot_bits = 0;
|
|
63
|
+
total_codebook_size = 0;
|
|
64
|
+
only_8bit = false;
|
|
65
|
+
set_derived_values();
|
|
58
66
|
}
|
|
59
67
|
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
namespace faiss {
|
|
68
|
+
AdditiveQuantizer::AdditiveQuantizer()
|
|
69
|
+
: AdditiveQuantizer(0, std::vector<size_t>()) {}
|
|
63
70
|
|
|
64
71
|
void AdditiveQuantizer::set_derived_values() {
|
|
65
72
|
tot_bits = 0;
|
|
66
|
-
|
|
73
|
+
only_8bit = true;
|
|
67
74
|
codebook_offsets.resize(M + 1, 0);
|
|
68
75
|
for (int i = 0; i < M; i++) {
|
|
69
76
|
int nbit = nbits[i];
|
|
70
77
|
size_t k = 1 << nbit;
|
|
71
78
|
codebook_offsets[i + 1] = codebook_offsets[i] + k;
|
|
72
79
|
tot_bits += nbit;
|
|
73
|
-
if (nbit
|
|
74
|
-
|
|
80
|
+
if (nbit != 0) {
|
|
81
|
+
only_8bit = false;
|
|
75
82
|
}
|
|
76
83
|
}
|
|
77
84
|
total_codebook_size = codebook_offsets[M];
|
|
85
|
+
switch (search_type) {
|
|
86
|
+
case ST_norm_float:
|
|
87
|
+
norm_bits = 32;
|
|
88
|
+
break;
|
|
89
|
+
case ST_norm_qint8:
|
|
90
|
+
case ST_norm_cqint8:
|
|
91
|
+
case ST_norm_lsq2x4:
|
|
92
|
+
case ST_norm_rq2x4:
|
|
93
|
+
norm_bits = 8;
|
|
94
|
+
break;
|
|
95
|
+
case ST_norm_qint4:
|
|
96
|
+
case ST_norm_cqint4:
|
|
97
|
+
norm_bits = 4;
|
|
98
|
+
break;
|
|
99
|
+
case ST_decompress:
|
|
100
|
+
case ST_LUT_nonorm:
|
|
101
|
+
case ST_norm_from_LUT:
|
|
102
|
+
default:
|
|
103
|
+
norm_bits = 0;
|
|
104
|
+
break;
|
|
105
|
+
}
|
|
106
|
+
tot_bits += norm_bits;
|
|
107
|
+
|
|
78
108
|
// convert bits to bytes
|
|
79
109
|
code_size = (tot_bits + 7) / 8;
|
|
80
110
|
}
|
|
81
111
|
|
|
112
|
+
void AdditiveQuantizer::train_norm(size_t n, const float* norms) {
|
|
113
|
+
norm_min = HUGE_VALF;
|
|
114
|
+
norm_max = -HUGE_VALF;
|
|
115
|
+
for (idx_t i = 0; i < n; i++) {
|
|
116
|
+
if (norms[i] < norm_min) {
|
|
117
|
+
norm_min = norms[i];
|
|
118
|
+
}
|
|
119
|
+
if (norms[i] > norm_max) {
|
|
120
|
+
norm_max = norms[i];
|
|
121
|
+
}
|
|
122
|
+
}
|
|
123
|
+
|
|
124
|
+
if (search_type == ST_norm_cqint8 || search_type == ST_norm_cqint4) {
|
|
125
|
+
size_t k = (1 << 8);
|
|
126
|
+
if (search_type == ST_norm_cqint4) {
|
|
127
|
+
k = (1 << 4);
|
|
128
|
+
}
|
|
129
|
+
Clustering1D clus(k);
|
|
130
|
+
clus.train_exact(n, norms);
|
|
131
|
+
qnorm.add(clus.k, clus.centroids.data());
|
|
132
|
+
} else if (search_type == ST_norm_lsq2x4 || search_type == ST_norm_rq2x4) {
|
|
133
|
+
std::unique_ptr<AdditiveQuantizer> aq;
|
|
134
|
+
if (search_type == ST_norm_lsq2x4) {
|
|
135
|
+
aq.reset(new LocalSearchQuantizer(1, 2, 4));
|
|
136
|
+
} else {
|
|
137
|
+
aq.reset(new ResidualQuantizer(1, 2, 4));
|
|
138
|
+
}
|
|
139
|
+
|
|
140
|
+
aq->train(n, norms);
|
|
141
|
+
// flatten aq codebooks
|
|
142
|
+
std::vector<float> flat_codebooks(1 << 8);
|
|
143
|
+
FAISS_THROW_IF_NOT(aq->codebooks.size() == 32);
|
|
144
|
+
|
|
145
|
+
// save norm tables for 4-bit fastscan search
|
|
146
|
+
norm_tabs = aq->codebooks;
|
|
147
|
+
|
|
148
|
+
// assume big endian
|
|
149
|
+
const float* c = norm_tabs.data();
|
|
150
|
+
for (size_t i = 0; i < 16; i++) {
|
|
151
|
+
for (size_t j = 0; j < 16; j++) {
|
|
152
|
+
flat_codebooks[i * 16 + j] = c[j] + c[16 + i];
|
|
153
|
+
}
|
|
154
|
+
}
|
|
155
|
+
|
|
156
|
+
qnorm.reset();
|
|
157
|
+
qnorm.add(1 << 8, flat_codebooks.data());
|
|
158
|
+
FAISS_THROW_IF_NOT(qnorm.ntotal == (1 << 8));
|
|
159
|
+
}
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
namespace {
|
|
163
|
+
|
|
164
|
+
// TODO
|
|
165
|
+
// https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
|
|
166
|
+
|
|
167
|
+
uint8_t encode_qint8(float x, float amin, float amax) {
|
|
168
|
+
float x1 = (x - amin) / (amax - amin) * 256;
|
|
169
|
+
int32_t xi = int32_t(floor(x1));
|
|
170
|
+
|
|
171
|
+
return xi < 0 ? 0 : xi > 255 ? 255 : xi;
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
uint8_t encode_qint4(float x, float amin, float amax) {
|
|
175
|
+
float x1 = (x - amin) / (amax - amin) * 16;
|
|
176
|
+
int32_t xi = int32_t(floor(x1));
|
|
177
|
+
|
|
178
|
+
return xi < 0 ? 0 : xi > 15 ? 15 : xi;
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
float decode_qint8(uint8_t i, float amin, float amax) {
|
|
182
|
+
return (i + 0.5) / 256 * (amax - amin) + amin;
|
|
183
|
+
}
|
|
184
|
+
|
|
185
|
+
float decode_qint4(uint8_t i, float amin, float amax) {
|
|
186
|
+
return (i + 0.5) / 16 * (amax - amin) + amin;
|
|
187
|
+
}
|
|
188
|
+
|
|
189
|
+
} // anonymous namespace
|
|
190
|
+
|
|
191
|
+
uint32_t AdditiveQuantizer::encode_qcint(float x) const {
|
|
192
|
+
idx_t id;
|
|
193
|
+
qnorm.assign(1, &x, &id, 1);
|
|
194
|
+
return uint32_t(id);
|
|
195
|
+
}
|
|
196
|
+
|
|
197
|
+
float AdditiveQuantizer::decode_qcint(uint32_t c) const {
|
|
198
|
+
return qnorm.get_xb()[c];
|
|
199
|
+
}
|
|
200
|
+
|
|
201
|
+
uint64_t AdditiveQuantizer::encode_norm(float norm) const {
|
|
202
|
+
switch (search_type) {
|
|
203
|
+
case ST_norm_float:
|
|
204
|
+
uint32_t inorm;
|
|
205
|
+
memcpy(&inorm, &norm, 4);
|
|
206
|
+
return inorm;
|
|
207
|
+
case ST_norm_qint8:
|
|
208
|
+
return encode_qint8(norm, norm_min, norm_max);
|
|
209
|
+
case ST_norm_qint4:
|
|
210
|
+
return encode_qint4(norm, norm_min, norm_max);
|
|
211
|
+
case ST_norm_lsq2x4:
|
|
212
|
+
case ST_norm_rq2x4:
|
|
213
|
+
case ST_norm_cqint8:
|
|
214
|
+
return encode_qcint(norm);
|
|
215
|
+
case ST_norm_cqint4:
|
|
216
|
+
return encode_qcint(norm);
|
|
217
|
+
case ST_decompress:
|
|
218
|
+
case ST_LUT_nonorm:
|
|
219
|
+
case ST_norm_from_LUT:
|
|
220
|
+
default:
|
|
221
|
+
return 0;
|
|
222
|
+
}
|
|
223
|
+
}
|
|
224
|
+
|
|
82
225
|
void AdditiveQuantizer::pack_codes(
|
|
83
226
|
size_t n,
|
|
84
227
|
const int32_t* codes,
|
|
85
228
|
uint8_t* packed_codes,
|
|
86
|
-
int64_t ld_codes
|
|
229
|
+
int64_t ld_codes,
|
|
230
|
+
const float* norms,
|
|
231
|
+
const float* centroids) const {
|
|
87
232
|
if (ld_codes == -1) {
|
|
88
233
|
ld_codes = M;
|
|
89
234
|
}
|
|
235
|
+
std::vector<float> norm_buf;
|
|
236
|
+
if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
|
|
237
|
+
search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
|
|
238
|
+
search_type == ST_norm_cqint4 || search_type == ST_norm_lsq2x4 ||
|
|
239
|
+
search_type == ST_norm_rq2x4) {
|
|
240
|
+
if (centroids != nullptr || !norms) {
|
|
241
|
+
norm_buf.resize(n);
|
|
242
|
+
std::vector<float> x_recons(n * d);
|
|
243
|
+
decode_unpacked(codes, x_recons.data(), n, ld_codes);
|
|
244
|
+
|
|
245
|
+
if (centroids != nullptr) {
|
|
246
|
+
// x = x + c
|
|
247
|
+
fvec_add(n * d, x_recons.data(), centroids, x_recons.data());
|
|
248
|
+
}
|
|
249
|
+
fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
|
|
250
|
+
norms = norm_buf.data();
|
|
251
|
+
}
|
|
252
|
+
}
|
|
90
253
|
#pragma omp parallel for if (n > 1000)
|
|
91
254
|
for (int64_t i = 0; i < n; i++) {
|
|
92
255
|
const int32_t* codes1 = codes + i * ld_codes;
|
|
@@ -94,6 +257,9 @@ void AdditiveQuantizer::pack_codes(
|
|
|
94
257
|
for (int m = 0; m < M; m++) {
|
|
95
258
|
bsw.write(codes1[m], nbits[m]);
|
|
96
259
|
}
|
|
260
|
+
if (norm_bits != 0) {
|
|
261
|
+
bsw.write(encode_norm(norms[i]), norm_bits);
|
|
262
|
+
}
|
|
97
263
|
}
|
|
98
264
|
}
|
|
99
265
|
|
|
@@ -118,10 +284,39 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
|
|
|
118
284
|
}
|
|
119
285
|
}
|
|
120
286
|
|
|
287
|
+
void AdditiveQuantizer::decode_unpacked(
|
|
288
|
+
const int32_t* code,
|
|
289
|
+
float* x,
|
|
290
|
+
size_t n,
|
|
291
|
+
int64_t ld_codes) const {
|
|
292
|
+
FAISS_THROW_IF_NOT_MSG(
|
|
293
|
+
is_trained, "The additive quantizer is not trained yet.");
|
|
294
|
+
|
|
295
|
+
if (ld_codes == -1) {
|
|
296
|
+
ld_codes = M;
|
|
297
|
+
}
|
|
298
|
+
|
|
299
|
+
// standard additive quantizer decoding
|
|
300
|
+
#pragma omp parallel for if (n > 1000)
|
|
301
|
+
for (int64_t i = 0; i < n; i++) {
|
|
302
|
+
const int32_t* codesi = code + i * ld_codes;
|
|
303
|
+
float* xi = x + i * d;
|
|
304
|
+
for (int m = 0; m < M; m++) {
|
|
305
|
+
int idx = codesi[m];
|
|
306
|
+
const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
|
|
307
|
+
if (m == 0) {
|
|
308
|
+
memcpy(xi, c, sizeof(*x) * d);
|
|
309
|
+
} else {
|
|
310
|
+
fvec_add(d, xi, c, xi);
|
|
311
|
+
}
|
|
312
|
+
}
|
|
313
|
+
}
|
|
314
|
+
}
|
|
315
|
+
|
|
121
316
|
AdditiveQuantizer::~AdditiveQuantizer() {}
|
|
122
317
|
|
|
123
318
|
/****************************************************************************
|
|
124
|
-
* Support for fast distance computations
|
|
319
|
+
* Support for fast distance computations in centroids
|
|
125
320
|
****************************************************************************/
|
|
126
321
|
|
|
127
322
|
void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
|
|
@@ -151,28 +346,33 @@ void AdditiveQuantizer::decode_64bit(idx_t bits, float* xi) const {
|
|
|
151
346
|
}
|
|
152
347
|
}
|
|
153
348
|
|
|
154
|
-
void AdditiveQuantizer::compute_LUT(
|
|
155
|
-
|
|
349
|
+
void AdditiveQuantizer::compute_LUT(
|
|
350
|
+
size_t n,
|
|
351
|
+
const float* xq,
|
|
352
|
+
float* LUT,
|
|
353
|
+
float alpha,
|
|
354
|
+
long ld_lut) const {
|
|
156
355
|
// in all cases, it is large matrix multiplication
|
|
157
356
|
|
|
158
357
|
FINTEGER ncenti = total_codebook_size;
|
|
159
358
|
FINTEGER di = d;
|
|
160
359
|
FINTEGER nqi = n;
|
|
161
|
-
|
|
360
|
+
FINTEGER ldc = ld_lut > 0 ? ld_lut : ncenti;
|
|
361
|
+
float zero = 0;
|
|
162
362
|
|
|
163
363
|
sgemm_("Transposed",
|
|
164
364
|
"Not transposed",
|
|
165
365
|
&ncenti,
|
|
166
366
|
&nqi,
|
|
167
367
|
&di,
|
|
168
|
-
&
|
|
368
|
+
&alpha,
|
|
169
369
|
codebooks.data(),
|
|
170
370
|
&di,
|
|
171
371
|
xq,
|
|
172
372
|
&di,
|
|
173
373
|
&zero,
|
|
174
374
|
LUT,
|
|
175
|
-
&
|
|
375
|
+
&ldc);
|
|
176
376
|
}
|
|
177
377
|
|
|
178
378
|
namespace {
|
|
@@ -201,7 +401,7 @@ void compute_inner_prod_with_LUT(
|
|
|
201
401
|
|
|
202
402
|
} // anonymous namespace
|
|
203
403
|
|
|
204
|
-
void AdditiveQuantizer::
|
|
404
|
+
void AdditiveQuantizer::knn_centroids_inner_product(
|
|
205
405
|
idx_t n,
|
|
206
406
|
const float* xq,
|
|
207
407
|
idx_t k,
|
|
@@ -227,7 +427,7 @@ void AdditiveQuantizer::knn_exact_inner_product(
|
|
|
227
427
|
}
|
|
228
428
|
}
|
|
229
429
|
|
|
230
|
-
void AdditiveQuantizer::
|
|
430
|
+
void AdditiveQuantizer::knn_centroids_L2(
|
|
231
431
|
idx_t n,
|
|
232
432
|
const float* xq,
|
|
233
433
|
idx_t k,
|
|
@@ -267,4 +467,106 @@ void AdditiveQuantizer::knn_exact_L2(
|
|
|
267
467
|
}
|
|
268
468
|
}
|
|
269
469
|
|
|
470
|
+
/****************************************************************************
|
|
471
|
+
* Support for fast distance computations in codes
|
|
472
|
+
****************************************************************************/
|
|
473
|
+
|
|
474
|
+
namespace {
|
|
475
|
+
|
|
476
|
+
float accumulate_IPs(
|
|
477
|
+
const AdditiveQuantizer& aq,
|
|
478
|
+
BitstringReader& bs,
|
|
479
|
+
const uint8_t* codes,
|
|
480
|
+
const float* LUT) {
|
|
481
|
+
float accu = 0;
|
|
482
|
+
for (int m = 0; m < aq.M; m++) {
|
|
483
|
+
size_t nbit = aq.nbits[m];
|
|
484
|
+
int idx = bs.read(nbit);
|
|
485
|
+
accu += LUT[idx];
|
|
486
|
+
LUT += (uint64_t)1 << nbit;
|
|
487
|
+
}
|
|
488
|
+
return accu;
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
} // anonymous namespace
|
|
492
|
+
|
|
493
|
+
template <>
|
|
494
|
+
float AdditiveQuantizer::
|
|
495
|
+
compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
|
|
496
|
+
const uint8_t* codes,
|
|
497
|
+
const float* LUT) const {
|
|
498
|
+
BitstringReader bs(codes, code_size);
|
|
499
|
+
return accumulate_IPs(*this, bs, codes, LUT);
|
|
500
|
+
}
|
|
501
|
+
|
|
502
|
+
template <>
|
|
503
|
+
float AdditiveQuantizer::
|
|
504
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
|
|
505
|
+
const uint8_t* codes,
|
|
506
|
+
const float* LUT) const {
|
|
507
|
+
BitstringReader bs(codes, code_size);
|
|
508
|
+
return -accumulate_IPs(*this, bs, codes, LUT);
|
|
509
|
+
}
|
|
510
|
+
|
|
511
|
+
template <>
|
|
512
|
+
float AdditiveQuantizer::
|
|
513
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
|
|
514
|
+
const uint8_t* codes,
|
|
515
|
+
const float* LUT) const {
|
|
516
|
+
BitstringReader bs(codes, code_size);
|
|
517
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
|
518
|
+
uint32_t norm_i = bs.read(32);
|
|
519
|
+
float norm2;
|
|
520
|
+
memcpy(&norm2, &norm_i, 4);
|
|
521
|
+
return norm2 - 2 * accu;
|
|
522
|
+
}
|
|
523
|
+
|
|
524
|
+
template <>
|
|
525
|
+
float AdditiveQuantizer::
|
|
526
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
|
|
527
|
+
const uint8_t* codes,
|
|
528
|
+
const float* LUT) const {
|
|
529
|
+
BitstringReader bs(codes, code_size);
|
|
530
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
|
531
|
+
uint32_t norm_i = bs.read(8);
|
|
532
|
+
float norm2 = decode_qcint(norm_i);
|
|
533
|
+
return norm2 - 2 * accu;
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
template <>
|
|
537
|
+
float AdditiveQuantizer::
|
|
538
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
|
|
539
|
+
const uint8_t* codes,
|
|
540
|
+
const float* LUT) const {
|
|
541
|
+
BitstringReader bs(codes, code_size);
|
|
542
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
|
543
|
+
uint32_t norm_i = bs.read(4);
|
|
544
|
+
float norm2 = decode_qcint(norm_i);
|
|
545
|
+
return norm2 - 2 * accu;
|
|
546
|
+
}
|
|
547
|
+
|
|
548
|
+
template <>
|
|
549
|
+
float AdditiveQuantizer::
|
|
550
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
|
|
551
|
+
const uint8_t* codes,
|
|
552
|
+
const float* LUT) const {
|
|
553
|
+
BitstringReader bs(codes, code_size);
|
|
554
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
|
555
|
+
uint32_t norm_i = bs.read(8);
|
|
556
|
+
float norm2 = decode_qint8(norm_i, norm_min, norm_max);
|
|
557
|
+
return norm2 - 2 * accu;
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
template <>
|
|
561
|
+
float AdditiveQuantizer::
|
|
562
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
|
|
563
|
+
const uint8_t* codes,
|
|
564
|
+
const float* LUT) const {
|
|
565
|
+
BitstringReader bs(codes, code_size);
|
|
566
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
|
567
|
+
uint32_t norm_i = bs.read(4);
|
|
568
|
+
float norm2 = decode_qint4(norm_i, norm_min, norm_max);
|
|
569
|
+
return norm2 - 2 * accu;
|
|
570
|
+
}
|
|
571
|
+
|
|
270
572
|
} // namespace faiss
|
|
@@ -11,6 +11,8 @@
|
|
|
11
11
|
#include <vector>
|
|
12
12
|
|
|
13
13
|
#include <faiss/Index.h>
|
|
14
|
+
#include <faiss/IndexFlat.h>
|
|
15
|
+
#include <faiss/impl/Quantizer.h>
|
|
14
16
|
|
|
15
17
|
namespace faiss {
|
|
16
18
|
|
|
@@ -20,58 +22,140 @@ namespace faiss {
|
|
|
20
22
|
* concatenation of M sub-vectors, additive quantizers sum M sub-vectors
|
|
21
23
|
* to get the decoded vector.
|
|
22
24
|
*/
|
|
23
|
-
struct AdditiveQuantizer {
|
|
24
|
-
size_t d; ///< size of the input vectors
|
|
25
|
+
struct AdditiveQuantizer : Quantizer {
|
|
25
26
|
size_t M; ///< number of codebooks
|
|
26
27
|
std::vector<size_t> nbits; ///< bits for each step
|
|
27
28
|
std::vector<float> codebooks; ///< codebooks
|
|
28
29
|
|
|
29
30
|
// derived values
|
|
30
|
-
std::vector<
|
|
31
|
-
size_t
|
|
32
|
-
size_t
|
|
31
|
+
std::vector<uint64_t> codebook_offsets;
|
|
32
|
+
size_t tot_bits; ///< total number of bits (indexes + norms)
|
|
33
|
+
size_t norm_bits; ///< bits allocated for the norms
|
|
33
34
|
size_t total_codebook_size; ///< size of the codebook in vectors
|
|
34
|
-
bool
|
|
35
|
+
bool only_8bit; ///< are all nbits = 8 (use faster decoder)
|
|
35
36
|
|
|
36
37
|
bool verbose; ///< verbose during training?
|
|
37
38
|
bool is_trained; ///< is trained or not
|
|
38
39
|
|
|
40
|
+
IndexFlat1D qnorm; ///< store and search norms
|
|
41
|
+
std::vector<float> norm_tabs; ///< store norms of codebook entries for 4-bit
|
|
42
|
+
///< fastscan search
|
|
43
|
+
|
|
44
|
+
/// norms and distance matrixes with beam search can get large, so use this
|
|
45
|
+
/// to control for the amount of memory that can be allocated
|
|
46
|
+
size_t max_mem_distances;
|
|
47
|
+
|
|
48
|
+
/// encode a norm into norm_bits bits
|
|
49
|
+
uint64_t encode_norm(float norm) const;
|
|
50
|
+
|
|
51
|
+
uint32_t encode_qcint(
|
|
52
|
+
float x) const; ///< encode norm by non-uniform scalar quantization
|
|
53
|
+
|
|
54
|
+
float decode_qcint(uint32_t c)
|
|
55
|
+
const; ///< decode norm by non-uniform scalar quantization
|
|
56
|
+
|
|
57
|
+
/// Encodes how search is performed and how vectors are encoded
|
|
58
|
+
enum Search_type_t {
|
|
59
|
+
ST_decompress, ///< decompress database vector
|
|
60
|
+
ST_LUT_nonorm, ///< use a LUT, don't include norms (OK for IP or
|
|
61
|
+
///< normalized vectors)
|
|
62
|
+
ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
|
|
63
|
+
///< is in O(M^2))
|
|
64
|
+
ST_norm_float, ///< use a LUT, and store float32 norm with the vectors
|
|
65
|
+
ST_norm_qint8, ///< use a LUT, and store 8bit-quantized norm
|
|
66
|
+
ST_norm_qint4,
|
|
67
|
+
ST_norm_cqint8, ///< use a LUT, and store non-uniform quantized norm
|
|
68
|
+
ST_norm_cqint4,
|
|
69
|
+
|
|
70
|
+
ST_norm_lsq2x4, ///< use a 2x4 bits lsq as norm quantizer (for fast
|
|
71
|
+
///< scan)
|
|
72
|
+
ST_norm_rq2x4, ///< use a 2x4 bits rq as norm quantizer (for fast scan)
|
|
73
|
+
};
|
|
74
|
+
|
|
75
|
+
AdditiveQuantizer(
|
|
76
|
+
size_t d,
|
|
77
|
+
const std::vector<size_t>& nbits,
|
|
78
|
+
Search_type_t search_type = ST_decompress);
|
|
79
|
+
|
|
80
|
+
AdditiveQuantizer();
|
|
81
|
+
|
|
39
82
|
///< compute derived values when d, M and nbits have been set
|
|
40
83
|
void set_derived_values();
|
|
41
84
|
|
|
42
|
-
///< Train the
|
|
43
|
-
|
|
85
|
+
///< Train the norm quantizer
|
|
86
|
+
void train_norm(size_t n, const float* norms);
|
|
87
|
+
|
|
88
|
+
void compute_codes(const float* x, uint8_t* codes, size_t n)
|
|
89
|
+
const override {
|
|
90
|
+
compute_codes_add_centroids(x, codes, n);
|
|
91
|
+
}
|
|
44
92
|
|
|
45
93
|
/** Encode a set of vectors
|
|
46
94
|
*
|
|
47
95
|
* @param x vectors to encode, size n * d
|
|
48
96
|
* @param codes output codes, size n * code_size
|
|
97
|
+
* @param centroids centroids to be added to x, size n * d
|
|
49
98
|
*/
|
|
50
|
-
virtual void
|
|
51
|
-
const
|
|
99
|
+
virtual void compute_codes_add_centroids(
|
|
100
|
+
const float* x,
|
|
101
|
+
uint8_t* codes,
|
|
102
|
+
size_t n,
|
|
103
|
+
const float* centroids = nullptr) const = 0;
|
|
52
104
|
|
|
53
105
|
/** pack a series of code to bit-compact format
|
|
54
106
|
*
|
|
55
|
-
* @param codes
|
|
107
|
+
* @param codes codes to be packed, size n * code_size
|
|
56
108
|
* @param packed_codes output bit-compact codes
|
|
57
|
-
* @param ld_codes
|
|
109
|
+
* @param ld_codes leading dimension of codes
|
|
110
|
+
* @param norms norms of the vectors (size n). Will be computed if
|
|
111
|
+
* needed but not provided
|
|
112
|
+
* @param centroids centroids to be added to x, size n * d
|
|
58
113
|
*/
|
|
59
114
|
void pack_codes(
|
|
60
115
|
size_t n,
|
|
61
116
|
const int32_t* codes,
|
|
62
117
|
uint8_t* packed_codes,
|
|
63
|
-
int64_t ld_codes = -1
|
|
118
|
+
int64_t ld_codes = -1,
|
|
119
|
+
const float* norms = nullptr,
|
|
120
|
+
const float* centroids = nullptr) const;
|
|
64
121
|
|
|
65
122
|
/** Decode a set of vectors
|
|
66
123
|
*
|
|
67
124
|
* @param codes codes to decode, size n * code_size
|
|
68
125
|
* @param x output vectors, size n * d
|
|
69
126
|
*/
|
|
70
|
-
void decode(const uint8_t* codes, float* x, size_t n) const;
|
|
127
|
+
void decode(const uint8_t* codes, float* x, size_t n) const override;
|
|
128
|
+
|
|
129
|
+
/** Decode a set of vectors in non-packed format
|
|
130
|
+
*
|
|
131
|
+
* @param codes codes to decode, size n * ld_codes
|
|
132
|
+
* @param x output vectors, size n * d
|
|
133
|
+
*/
|
|
134
|
+
virtual void decode_unpacked(
|
|
135
|
+
const int32_t* codes,
|
|
136
|
+
float* x,
|
|
137
|
+
size_t n,
|
|
138
|
+
int64_t ld_codes = -1) const;
|
|
71
139
|
|
|
72
140
|
/****************************************************************************
|
|
73
|
-
*
|
|
74
|
-
|
|
141
|
+
* Search functions in an external set of codes.
|
|
142
|
+
****************************************************************************/
|
|
143
|
+
|
|
144
|
+
/// Also determines what's in the codes
|
|
145
|
+
Search_type_t search_type;
|
|
146
|
+
|
|
147
|
+
/// min/max for quantization of norms
|
|
148
|
+
float norm_min, norm_max;
|
|
149
|
+
|
|
150
|
+
template <bool is_IP, Search_type_t effective_search_type>
|
|
151
|
+
float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
|
|
152
|
+
|
|
153
|
+
/*
|
|
154
|
+
float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
|
|
155
|
+
*/
|
|
156
|
+
/****************************************************************************
|
|
157
|
+
* Support for exhaustive distance computations with all the centroids.
|
|
158
|
+
* Hence, the number of these centroids should not be too large.
|
|
75
159
|
****************************************************************************/
|
|
76
160
|
using idx_t = Index::idx_t;
|
|
77
161
|
|
|
@@ -83,11 +167,18 @@ struct AdditiveQuantizer {
|
|
|
83
167
|
*
|
|
84
168
|
* @param xq query vector, size (n, d)
|
|
85
169
|
* @param LUT look-up table, size (n, total_codebook_size)
|
|
170
|
+
* @param alpha compute alpha * inner-product
|
|
171
|
+
* @param ld_lut leading dimension of LUT
|
|
86
172
|
*/
|
|
87
|
-
void compute_LUT(
|
|
173
|
+
virtual void compute_LUT(
|
|
174
|
+
size_t n,
|
|
175
|
+
const float* xq,
|
|
176
|
+
float* LUT,
|
|
177
|
+
float alpha = 1.0f,
|
|
178
|
+
long ld_lut = -1) const;
|
|
88
179
|
|
|
89
180
|
/// exact IP search
|
|
90
|
-
void
|
|
181
|
+
void knn_centroids_inner_product(
|
|
91
182
|
idx_t n,
|
|
92
183
|
const float* xq,
|
|
93
184
|
idx_t k,
|
|
@@ -101,7 +192,7 @@ struct AdditiveQuantizer {
|
|
|
101
192
|
void compute_centroid_norms(float* norms) const;
|
|
102
193
|
|
|
103
194
|
/** Exact L2 search, with precomputed norms */
|
|
104
|
-
void
|
|
195
|
+
void knn_centroids_L2(
|
|
105
196
|
idx_t n,
|
|
106
197
|
const float* xq,
|
|
107
198
|
idx_t k,
|