faiss 0.2.3 → 0.2.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/LICENSE.txt +1 -1
- data/README.md +23 -21
- data/ext/faiss/extconf.rb +11 -0
- data/ext/faiss/index.cpp +4 -4
- data/ext/faiss/index_binary.cpp +6 -6
- data/ext/faiss/product_quantizer.cpp +4 -4
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/AutoTune.cpp +13 -0
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/IVFlib.cpp +101 -2
- data/vendor/faiss/faiss/IVFlib.h +26 -2
- data/vendor/faiss/faiss/Index.cpp +36 -3
- data/vendor/faiss/faiss/Index.h +43 -6
- data/vendor/faiss/faiss/Index2Layer.cpp +24 -93
- data/vendor/faiss/faiss/Index2Layer.h +8 -17
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +610 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.h +253 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.cpp +299 -0
- data/vendor/faiss/faiss/IndexAdditiveQuantizerFastScan.h +199 -0
- data/vendor/faiss/faiss/IndexBinary.cpp +20 -4
- data/vendor/faiss/faiss/IndexBinary.h +18 -3
- data/vendor/faiss/faiss/IndexBinaryFlat.cpp +9 -2
- data/vendor/faiss/faiss/IndexBinaryFlat.h +4 -2
- data/vendor/faiss/faiss/IndexBinaryFromFloat.cpp +4 -1
- data/vendor/faiss/faiss/IndexBinaryFromFloat.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.cpp +5 -1
- data/vendor/faiss/faiss/IndexBinaryHNSW.h +2 -1
- data/vendor/faiss/faiss/IndexBinaryHash.cpp +17 -4
- data/vendor/faiss/faiss/IndexBinaryHash.h +8 -4
- data/vendor/faiss/faiss/IndexBinaryIVF.cpp +28 -13
- data/vendor/faiss/faiss/IndexBinaryIVF.h +10 -7
- data/vendor/faiss/faiss/IndexFastScan.cpp +626 -0
- data/vendor/faiss/faiss/IndexFastScan.h +145 -0
- data/vendor/faiss/faiss/IndexFlat.cpp +52 -69
- data/vendor/faiss/faiss/IndexFlat.h +16 -19
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +101 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +59 -0
- data/vendor/faiss/faiss/IndexHNSW.cpp +66 -138
- data/vendor/faiss/faiss/IndexHNSW.h +4 -2
- data/vendor/faiss/faiss/IndexIDMap.cpp +247 -0
- data/vendor/faiss/faiss/IndexIDMap.h +107 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +200 -40
- data/vendor/faiss/faiss/IndexIVF.h +59 -22
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +393 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +183 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.cpp +590 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizerFastScan.h +171 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.cpp +1290 -0
- data/vendor/faiss/faiss/IndexIVFFastScan.h +213 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +43 -26
- data/vendor/faiss/faiss/IndexIVFFlat.h +4 -2
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +238 -53
- data/vendor/faiss/faiss/IndexIVFPQ.h +6 -2
- data/vendor/faiss/faiss/IndexIVFPQFastScan.cpp +23 -852
- data/vendor/faiss/faiss/IndexIVFPQFastScan.h +7 -112
- data/vendor/faiss/faiss/IndexIVFPQR.cpp +3 -3
- data/vendor/faiss/faiss/IndexIVFPQR.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +63 -40
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +23 -7
- data/vendor/faiss/faiss/IndexLSH.cpp +8 -32
- data/vendor/faiss/faiss/IndexLSH.h +4 -16
- data/vendor/faiss/faiss/IndexLattice.cpp +7 -1
- data/vendor/faiss/faiss/IndexLattice.h +3 -1
- data/vendor/faiss/faiss/IndexNNDescent.cpp +4 -5
- data/vendor/faiss/faiss/IndexNNDescent.h +2 -1
- data/vendor/faiss/faiss/IndexNSG.cpp +37 -5
- data/vendor/faiss/faiss/IndexNSG.h +25 -1
- data/vendor/faiss/faiss/IndexPQ.cpp +108 -120
- data/vendor/faiss/faiss/IndexPQ.h +21 -22
- data/vendor/faiss/faiss/IndexPQFastScan.cpp +15 -450
- data/vendor/faiss/faiss/IndexPQFastScan.h +15 -78
- data/vendor/faiss/faiss/IndexPreTransform.cpp +47 -8
- data/vendor/faiss/faiss/IndexPreTransform.h +15 -3
- data/vendor/faiss/faiss/IndexRefine.cpp +36 -4
- data/vendor/faiss/faiss/IndexRefine.h +14 -2
- data/vendor/faiss/faiss/IndexReplicas.cpp +4 -2
- data/vendor/faiss/faiss/IndexReplicas.h +2 -1
- data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp +438 -0
- data/vendor/faiss/faiss/IndexRowwiseMinMax.h +92 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +28 -43
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +8 -23
- data/vendor/faiss/faiss/IndexShards.cpp +4 -1
- data/vendor/faiss/faiss/IndexShards.h +2 -1
- data/vendor/faiss/faiss/MetaIndexes.cpp +5 -178
- data/vendor/faiss/faiss/MetaIndexes.h +3 -81
- data/vendor/faiss/faiss/VectorTransform.cpp +45 -1
- data/vendor/faiss/faiss/VectorTransform.h +25 -4
- data/vendor/faiss/faiss/clone_index.cpp +26 -3
- data/vendor/faiss/faiss/clone_index.h +3 -0
- data/vendor/faiss/faiss/cppcontrib/SaDecodeKernels.h +300 -0
- data/vendor/faiss/faiss/cppcontrib/detail/CoarseBitType.h +24 -0
- data/vendor/faiss/faiss/cppcontrib/detail/UintReader.h +195 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-avx2-inl.h +2058 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-inl.h +408 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/Level2-neon-inl.h +2147 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMax-inl.h +460 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/MinMaxFP16-inl.h +465 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-avx2-inl.h +1618 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-inl.h +251 -0
- data/vendor/faiss/faiss/cppcontrib/sa_decode/PQ-neon-inl.h +1452 -0
- data/vendor/faiss/faiss/gpu/GpuAutoTune.cpp +1 -0
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -6
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/gpu/GpuIndex.h +28 -4
- data/vendor/faiss/faiss/gpu/GpuIndexBinaryFlat.h +2 -1
- data/vendor/faiss/faiss/gpu/GpuIndexFlat.h +10 -8
- data/vendor/faiss/faiss/gpu/GpuIndexIVF.h +75 -14
- data/vendor/faiss/faiss/gpu/GpuIndexIVFFlat.h +19 -32
- data/vendor/faiss/faiss/gpu/GpuIndexIVFPQ.h +22 -31
- data/vendor/faiss/faiss/gpu/GpuIndexIVFScalarQuantizer.h +22 -28
- data/vendor/faiss/faiss/gpu/GpuResources.cpp +14 -0
- data/vendor/faiss/faiss/gpu/GpuResources.h +16 -3
- data/vendor/faiss/faiss/gpu/StandardGpuResources.cpp +3 -3
- data/vendor/faiss/faiss/gpu/impl/IndexUtils.h +32 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexBinaryFlat.cpp +1 -0
- data/vendor/faiss/faiss/gpu/test/TestGpuIndexFlat.cpp +311 -75
- data/vendor/faiss/faiss/gpu/test/TestUtils.cpp +10 -0
- data/vendor/faiss/faiss/gpu/test/TestUtils.h +3 -0
- data/vendor/faiss/faiss/gpu/test/demo_ivfpq_indexing_gpu.cpp +2 -2
- data/vendor/faiss/faiss/gpu/utils/DeviceUtils.h +5 -4
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +331 -29
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +110 -19
- data/vendor/faiss/faiss/impl/AuxIndexStructures.cpp +0 -54
- data/vendor/faiss/faiss/impl/AuxIndexStructures.h +0 -76
- data/vendor/faiss/faiss/impl/DistanceComputer.h +64 -0
- data/vendor/faiss/faiss/impl/HNSW.cpp +133 -32
- data/vendor/faiss/faiss/impl/HNSW.h +19 -16
- data/vendor/faiss/faiss/impl/IDSelector.cpp +125 -0
- data/vendor/faiss/faiss/impl/IDSelector.h +135 -0
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +378 -217
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +106 -29
- data/vendor/faiss/faiss/impl/LookupTableScaler.h +77 -0
- data/vendor/faiss/faiss/impl/NNDescent.cpp +1 -0
- data/vendor/faiss/faiss/impl/NSG.cpp +1 -4
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.cpp +383 -0
- data/vendor/faiss/faiss/impl/ProductAdditiveQuantizer.h +154 -0
- data/vendor/faiss/faiss/impl/ProductQuantizer.cpp +225 -145
- data/vendor/faiss/faiss/impl/ProductQuantizer.h +29 -10
- data/vendor/faiss/faiss/impl/Quantizer.h +43 -0
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +521 -55
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +94 -16
- data/vendor/faiss/faiss/impl/ResultHandler.h +96 -0
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +108 -191
- data/vendor/faiss/faiss/impl/ScalarQuantizer.h +18 -18
- data/vendor/faiss/faiss/impl/index_read.cpp +338 -24
- data/vendor/faiss/faiss/impl/index_write.cpp +300 -18
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +303 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/impl/pq4_fast_scan.cpp +56 -16
- data/vendor/faiss/faiss/impl/pq4_fast_scan.h +25 -8
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_1.cpp +66 -25
- data/vendor/faiss/faiss/impl/pq4_fast_scan_search_qbs.cpp +75 -27
- data/vendor/faiss/faiss/index_factory.cpp +772 -412
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/index_io.h +5 -0
- data/vendor/faiss/faiss/invlists/DirectMap.cpp +1 -0
- data/vendor/faiss/faiss/invlists/InvertedLists.cpp +4 -1
- data/vendor/faiss/faiss/invlists/OnDiskInvertedLists.cpp +2 -1
- data/vendor/faiss/faiss/python/python_callbacks.cpp +27 -0
- data/vendor/faiss/faiss/python/python_callbacks.h +15 -0
- data/vendor/faiss/faiss/utils/Heap.h +31 -15
- data/vendor/faiss/faiss/utils/distances.cpp +384 -58
- data/vendor/faiss/faiss/utils/distances.h +149 -18
- data/vendor/faiss/faiss/utils/distances_simd.cpp +776 -6
- data/vendor/faiss/faiss/utils/extra_distances.cpp +12 -7
- data/vendor/faiss/faiss/utils/extra_distances.h +3 -1
- data/vendor/faiss/faiss/utils/fp16-fp16c.h +21 -0
- data/vendor/faiss/faiss/utils/fp16-inl.h +101 -0
- data/vendor/faiss/faiss/utils/fp16.h +11 -0
- data/vendor/faiss/faiss/utils/hamming-inl.h +54 -0
- data/vendor/faiss/faiss/utils/hamming.cpp +0 -48
- data/vendor/faiss/faiss/utils/ordered_key_value.h +10 -0
- data/vendor/faiss/faiss/utils/quantize_lut.cpp +62 -0
- data/vendor/faiss/faiss/utils/quantize_lut.h +20 -0
- data/vendor/faiss/faiss/utils/random.cpp +53 -0
- data/vendor/faiss/faiss/utils/random.h +5 -0
- data/vendor/faiss/faiss/utils/simdlib_avx2.h +4 -0
- data/vendor/faiss/faiss/utils/simdlib_emulated.h +6 -1
- data/vendor/faiss/faiss/utils/simdlib_neon.h +7 -2
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +46 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
- data/vendor/faiss/faiss/IndexResidual.h +0 -152
data/vendor/faiss/faiss/IndexRowwiseMinMax.cpp (new file, @@ -0,0 +1,438 @@):

```cpp
#include <faiss/IndexRowwiseMinMax.h>

#include <cstdint>
#include <cstring>
#include <limits>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/fp16.h>

namespace faiss {

namespace {

using idx_t = faiss::Index::idx_t;

struct StorageMinMaxFP16 {
    uint16_t scaler;
    uint16_t minv;

    inline void from_floats(const float float_scaler, const float float_minv) {
        scaler = encode_fp16(float_scaler);
        minv = encode_fp16(float_minv);
    }

    inline void to_floats(float& float_scaler, float& float_minv) const {
        float_scaler = decode_fp16(scaler);
        float_minv = decode_fp16(minv);
    }
};

struct StorageMinMaxFP32 {
    float scaler;
    float minv;

    inline void from_floats(const float float_scaler, const float float_minv) {
        scaler = float_scaler;
        minv = float_minv;
    }

    inline void to_floats(float& float_scaler, float& float_minv) const {
        float_scaler = scaler;
        float_minv = minv;
    }
};

template <typename StorageMinMaxT>
void sa_encode_impl(
        const IndexRowwiseMinMaxBase* const index,
        const idx_t n_input,
        const float* x_input,
        uint8_t* bytes_output) {
    // process chunks
    const size_t chunk_size = rowwise_minmax_sa_encode_bs;

    // useful variables
    const Index* const sub_index = index->index;
    const int d = index->d;

    // the code size of the subindex
    const size_t old_code_size = sub_index->sa_code_size();
    // the code size of the index
    const size_t new_code_size = index->sa_code_size();

    // allocate tmp buffers
    std::vector<float> tmp(chunk_size * d);
    std::vector<StorageMinMaxT> minmax(chunk_size);

    // all the elements to process
    size_t n_left = n_input;

    const float* __restrict x = x_input;
    uint8_t* __restrict bytes = bytes_output;

    while (n_left > 0) {
        // current portion to be processed
        const idx_t n = std::min(n_left, chunk_size);

        // allocate a temporary buffer and do the rescale
        for (idx_t i = 0; i < n; i++) {
            // compute min & max values
            float minv = std::numeric_limits<float>::max();
            float maxv = std::numeric_limits<float>::lowest();

            const float* const vec_in = x + i * d;
            for (idx_t j = 0; j < d; j++) {
                minv = std::min(minv, vec_in[j]);
                maxv = std::max(maxv, vec_in[j]);
            }

            // save the coefficients
            const float scaler = maxv - minv;
            minmax[i].from_floats(scaler, minv);

            // and load them back, because the coefficients might
            // be modified.
            float actual_scaler = 0;
            float actual_minv = 0;
            minmax[i].to_floats(actual_scaler, actual_minv);

            float* const vec_out = tmp.data() + i * d;
            if (actual_scaler == 0) {
                for (idx_t j = 0; j < d; j++) {
                    vec_out[j] = 0;
                }
            } else {
                float inv_actual_scaler = 1.0f / actual_scaler;
                for (idx_t j = 0; j < d; j++) {
                    vec_out[j] = (vec_in[j] - actual_minv) * inv_actual_scaler;
                }
            }
        }

        // do the coding
        sub_index->sa_encode(n, tmp.data(), bytes);

        // rearrange
        for (idx_t i = n; (i--) > 0;) {
            // move a single index
            std::memmove(
                    bytes + i * new_code_size + (new_code_size - old_code_size),
                    bytes + i * old_code_size,
                    old_code_size);

            // save min & max values
            StorageMinMaxT* fpv = reinterpret_cast<StorageMinMaxT*>(
                    bytes + i * new_code_size);
            *fpv = minmax[i];
        }

        // next chunk
        x += n * d;
        bytes += n * new_code_size;

        n_left -= n;
    }
}

template <typename StorageMinMaxT>
void sa_decode_impl(
        const IndexRowwiseMinMaxBase* const index,
        const idx_t n_input,
        const uint8_t* bytes_input,
        float* x_output) {
    // process chunks
    const size_t chunk_size = rowwise_minmax_sa_decode_bs;

    // useful variables
    const Index* const sub_index = index->index;
    const int d = index->d;

    // the code size of the subindex
    const size_t old_code_size = sub_index->sa_code_size();
    // the code size of the index
    const size_t new_code_size = index->sa_code_size();

    // allocate tmp buffers
    std::vector<uint8_t> tmp(
            (chunk_size < n_input ? chunk_size : n_input) * old_code_size);
    std::vector<StorageMinMaxFP16> minmax(
            (chunk_size < n_input ? chunk_size : n_input));

    // all the elements to process
    size_t n_left = n_input;

    const uint8_t* __restrict bytes = bytes_input;
    float* __restrict x = x_output;

    while (n_left > 0) {
        // current portion to be processed
        const idx_t n = std::min(n_left, chunk_size);

        // rearrange
        for (idx_t i = 0; i < n; i++) {
            std::memcpy(
                    tmp.data() + i * old_code_size,
                    bytes + i * new_code_size + (new_code_size - old_code_size),
                    old_code_size);
        }

        // decode
        sub_index->sa_decode(n, tmp.data(), x);

        // scale back
        for (idx_t i = 0; i < n; i++) {
            const uint8_t* const vec_in = bytes + i * new_code_size;
            StorageMinMaxT fpv =
                    *(reinterpret_cast<const StorageMinMaxT*>(vec_in));

            float scaler = 0;
            float minv = 0;
            fpv.to_floats(scaler, minv);

            float* const __restrict vec = x + d * i;

            for (idx_t j = 0; j < d; j++) {
                vec[j] = vec[j] * scaler + minv;
            }
        }

        // next chunk
        bytes += n * new_code_size;
        x += n * d;

        n_left -= n;
    }
}

//
template <typename StorageMinMaxT>
void train_inplace_impl(
        IndexRowwiseMinMaxBase* const index,
        idx_t n,
        float* x) {
    // useful variables
    Index* const sub_index = index->index;
    const int d = index->d;

    // save normalizing coefficients
    std::vector<StorageMinMaxT> minmax(n);

    // normalize
#pragma omp for
    for (idx_t i = 0; i < n; i++) {
        // compute min & max values
        float minv = std::numeric_limits<float>::max();
        float maxv = std::numeric_limits<float>::lowest();

        float* const vec = x + i * d;
        for (idx_t j = 0; j < d; j++) {
            minv = std::min(minv, vec[j]);
            maxv = std::max(maxv, vec[j]);
        }

        // save the coefficients
        const float scaler = maxv - minv;
        minmax[i].from_floats(scaler, minv);

        // and load them back, because the coefficients might
        // be modified.
        float actual_scaler = 0;
        float actual_minv = 0;
        minmax[i].to_floats(actual_scaler, actual_minv);

        if (actual_scaler == 0) {
            for (idx_t j = 0; j < d; j++) {
                vec[j] = 0;
            }
        } else {
            float inv_actual_scaler = 1.0f / actual_scaler;
            for (idx_t j = 0; j < d; j++) {
                vec[j] = (vec[j] - actual_minv) * inv_actual_scaler;
            }
        }
    }

    // train the subindex
    sub_index->train(n, x);

    // rescale data back
    for (idx_t i = 0; i < n; i++) {
        float scaler = 0;
        float minv = 0;
        minmax[i].to_floats(scaler, minv);

        float* const vec = x + i * d;

        for (idx_t j = 0; j < d; j++) {
            vec[j] = vec[j] * scaler + minv;
        }
    }
}

//
template <typename StorageMinMaxT>
void train_impl(IndexRowwiseMinMaxBase* const index, idx_t n, const float* x) {
    // the default training that creates a copy of the input data

    // useful variables
    Index* const sub_index = index->index;
    const int d = index->d;

    // temp buffer
    std::vector<float> tmp(n * d);

#pragma omp for
    for (idx_t i = 0; i < n; i++) {
        // compute min & max values
        float minv = std::numeric_limits<float>::max();
        float maxv = std::numeric_limits<float>::lowest();

        const float* const __restrict vec_in = x + i * d;
        for (idx_t j = 0; j < d; j++) {
            minv = std::min(minv, vec_in[j]);
            maxv = std::max(maxv, vec_in[j]);
        }

        const float scaler = maxv - minv;

        // save the coefficients
        StorageMinMaxT storage;
        storage.from_floats(scaler, minv);

        // and load them back, because the coefficients might
        // be modified.
        float actual_scaler = 0;
        float actual_minv = 0;
        storage.to_floats(actual_scaler, actual_minv);

        float* const __restrict vec_out = tmp.data() + i * d;
        if (actual_scaler == 0) {
            for (idx_t j = 0; j < d; j++) {
                vec_out[j] = 0;
            }
        } else {
            float inv_actual_scaler = 1.0f / actual_scaler;
            for (idx_t j = 0; j < d; j++) {
                vec_out[j] = (vec_in[j] - actual_minv) * inv_actual_scaler;
            }
        }
    }

    sub_index->train(n, tmp.data());
}

} // namespace

// block size for performing sa_encode and sa_decode
int rowwise_minmax_sa_encode_bs = 16384;
int rowwise_minmax_sa_decode_bs = 16384;

/*********************************************************
 * IndexRowwiseMinMaxBase implementation
 ********************************************************/

IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase(Index* index)
        : Index(index->d, index->metric_type),
          index{index},
          own_fields{false} {}

IndexRowwiseMinMaxBase::IndexRowwiseMinMaxBase()
        : index{nullptr}, own_fields{false} {}

IndexRowwiseMinMaxBase::~IndexRowwiseMinMaxBase() {
    if (own_fields) {
        delete index;
        index = nullptr;
    }
}

void IndexRowwiseMinMaxBase::add(idx_t, const float*) {
    FAISS_THROW_MSG("add not implemented for this type of index");
}

void IndexRowwiseMinMaxBase::search(
        idx_t,
        const float*,
        idx_t,
        float*,
        idx_t*,
        const SearchParameters*) const {
    FAISS_THROW_MSG("search not implemented for this type of index");
}

void IndexRowwiseMinMaxBase::reset() {
    FAISS_THROW_MSG("reset not implemented for this type of index");
}

/*********************************************************
 * IndexRowwiseMinMaxFP16 implementation
 ********************************************************/

IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16(Index* index)
        : IndexRowwiseMinMaxBase(index) {}

IndexRowwiseMinMaxFP16::IndexRowwiseMinMaxFP16() : IndexRowwiseMinMaxBase() {}

size_t IndexRowwiseMinMaxFP16::sa_code_size() const {
    return index->sa_code_size() + 2 * sizeof(uint16_t);
}

void IndexRowwiseMinMaxFP16::sa_encode(
        idx_t n_input,
        const float* x_input,
        uint8_t* bytes_output) const {
    sa_encode_impl<StorageMinMaxFP16>(this, n_input, x_input, bytes_output);
}

void IndexRowwiseMinMaxFP16::sa_decode(
        idx_t n_input,
        const uint8_t* bytes_input,
        float* x_output) const {
    sa_decode_impl<StorageMinMaxFP16>(this, n_input, bytes_input, x_output);
}

void IndexRowwiseMinMaxFP16::train(idx_t n, const float* x) {
    train_impl<StorageMinMaxFP16>(this, n, x);
}

void IndexRowwiseMinMaxFP16::train_inplace(idx_t n, float* x) {
    train_inplace_impl<StorageMinMaxFP16>(this, n, x);
}

/*********************************************************
 * IndexRowwiseMinMax implementation
 ********************************************************/

IndexRowwiseMinMax::IndexRowwiseMinMax(Index* index)
        : IndexRowwiseMinMaxBase(index) {}

IndexRowwiseMinMax::IndexRowwiseMinMax() : IndexRowwiseMinMaxBase() {}

size_t IndexRowwiseMinMax::sa_code_size() const {
    return index->sa_code_size() + 2 * sizeof(float);
}

void IndexRowwiseMinMax::sa_encode(
        idx_t n_input,
        const float* x_input,
        uint8_t* bytes_output) const {
    sa_encode_impl<StorageMinMaxFP32>(this, n_input, x_input, bytes_output);
}

void IndexRowwiseMinMax::sa_decode(
        idx_t n_input,
        const uint8_t* bytes_input,
        float* x_output) const {
    sa_decode_impl<StorageMinMaxFP32>(this, n_input, bytes_input, x_output);
}

void IndexRowwiseMinMax::train(idx_t n, const float* x) {
    train_impl<StorageMinMaxFP32>(this, n, x);
}

void IndexRowwiseMinMax::train_inplace(idx_t n, float* x) {
    train_inplace_impl<StorageMinMaxFP32>(this, n, x);
}

} // namespace faiss
```
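For reference, the transform implemented by `sa_encode_impl`/`sa_decode_impl` above is a per-row affine round trip: encode normalizes each row to [0, 1] and stores `(scaler, minv)`; decode applies `output_rescaled = scaler * output + minv`. A minimal standalone sketch of that round trip (illustrative only, not part of the released code):

```cpp
// Standalone sketch: the rowwise min-max round trip that
// sa_encode_impl()/sa_decode_impl() wrap around the subindex codec.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
    std::vector<float> row = {2.0f, 5.0f, 3.5f};

    // Encode side: normalize the row to [0, 1], keep (scaler, minv).
    const float minv = *std::min_element(row.begin(), row.end());
    const float maxv = *std::max_element(row.begin(), row.end());
    const float scaler = maxv - minv; // 3.0f here

    std::vector<float> normalized(row.size());
    for (std::size_t j = 0; j < row.size(); j++) {
        // A zero scaler (constant row) maps everything to 0, as in the diff.
        normalized[j] = (scaler == 0) ? 0.0f : (row[j] - minv) / scaler;
    }

    // Decode side: output_rescaled = scaler * output + minv.
    for (std::size_t j = 0; j < row.size(); j++) {
        const float restored = scaler * normalized[j] + minv;
        assert(std::fabs(restored - row[j]) < 1e-6f);
    }
    return 0;
}
```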
data/vendor/faiss/faiss/IndexRowwiseMinMax.h (new file, @@ -0,0 +1,92 @@):

```cpp
#pragma once

#include <cstdint>
#include <vector>

#include <faiss/Index.h>
#include <faiss/impl/platform_macros.h>

namespace faiss {

/// Index wrapper that performs rowwise normalization to [0,1], preserving
/// the coefficients. This is a vector codec index only.
///
/// Basically, this index performs a rowwise scaling to [0,1] of every row
/// in an input dataset before calling subindex::train() and
/// subindex::sa_encode(). sa_encode() call stores the scaling coefficients
/// (scaler and minv) in the very beginning of every output code. The format:
/// [scaler][minv][subindex::sa_encode() output]
/// The de-scaling in sa_decode() is done using:
/// output_rescaled = scaler * output + minv
///
/// An additional ::train_inplace() function is provided in order to do
/// an inplace scaling before calling subindex::train() and, thus, avoiding
/// the cloning of the input dataset, but modifying the input dataset because
/// of the scaling and the scaling back. It is up to user to call
/// this function instead of ::train()
///
/// Derived classes provide different data types for scaling coefficients.
/// Currently, versions with fp16 and fp32 scaling coefficients are available.
/// * fp16 version adds 4 extra bytes per encoded vector
/// * fp32 version adds 8 extra bytes per encoded vector

/// Provides base functions for rowwise normalizing indices.
struct IndexRowwiseMinMaxBase : Index {
    /// sub-index
    Index* index;

    /// whether the subindex needs to be freed in the destructor.
    bool own_fields;

    explicit IndexRowwiseMinMaxBase(Index* index);

    IndexRowwiseMinMaxBase();
    ~IndexRowwiseMinMaxBase() override;

    void add(idx_t n, const float* x) override;
    void search(
            idx_t n,
            const float* x,
            idx_t k,
            float* distances,
            idx_t* labels,
            const SearchParameters* params = nullptr) const override;

    void reset() override;

    virtual void train_inplace(idx_t n, float* x) = 0;
};

/// Stores scaling coefficients as fp16 values.
struct IndexRowwiseMinMaxFP16 : IndexRowwiseMinMaxBase {
    explicit IndexRowwiseMinMaxFP16(Index* index);

    IndexRowwiseMinMaxFP16();

    void train(idx_t n, const float* x) override;
    void train_inplace(idx_t n, float* x) override;

    size_t sa_code_size() const override;
    void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
    void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};

/// Stores scaling coefficients as fp32 values.
struct IndexRowwiseMinMax : IndexRowwiseMinMaxBase {
    explicit IndexRowwiseMinMax(Index* index);

    IndexRowwiseMinMax();

    void train(idx_t n, const float* x) override;
    void train_inplace(idx_t n, float* x) override;

    size_t sa_code_size() const override;
    void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
    void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
};

/// block size for performing sa_encode and sa_decode
FAISS_API extern int rowwise_minmax_sa_encode_bs;
FAISS_API extern int rowwise_minmax_sa_decode_bs;

} // namespace faiss
```
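A short usage sketch of the wrapper declared above, assuming the vendored faiss C++ API as shown; the `IndexScalarQuantizer` subindex and the filler data are illustrative choices, not from this release:

```cpp
// Sketch: wrap a codec index with IndexRowwiseMinMaxFP16 and round-trip
// vectors through sa_encode()/sa_decode().
#include <faiss/IndexRowwiseMinMax.h>
#include <faiss/IndexScalarQuantizer.h>

#include <cstddef>
#include <cstdint>
#include <vector>

void rowwise_minmax_example() {
    const int d = 32;
    const faiss::Index::idx_t n = 100;

    std::vector<float> data(n * d);
    for (std::size_t i = 0; i < data.size(); i++) {
        data[i] = static_cast<float>(i % 97) / 97.0f; // deterministic filler
    }

    // Wrap an 8-bit scalar quantizer; fp16 coefficients add
    // 2 * sizeof(uint16_t) == 4 bytes per encoded vector.
    faiss::IndexScalarQuantizer sub(d, faiss::ScalarQuantizer::QT_8bit);
    faiss::IndexRowwiseMinMaxFP16 index(&sub);
    index.train(n, data.data()); // or train_inplace() to avoid the copy

    std::vector<uint8_t> codes(n * index.sa_code_size());
    index.sa_encode(n, data.data(), codes.data());

    std::vector<float> decoded(n * d);
    index.sa_decode(n, codes.data(), decoded.data());
}
```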
data/vendor/faiss/faiss/IndexScalarQuantizer.cpp:

```diff
@@ -16,6 +16,7 @@
 
 #include <faiss/impl/AuxIndexStructures.h>
 #include <faiss/impl/FaissAssert.h>
+#include <faiss/impl/IDSelector.h>
 #include <faiss/impl/ScalarQuantizer.h>
 #include <faiss/utils/utils.h>
 
@@ -29,7 +30,7 @@ IndexScalarQuantizer::IndexScalarQuantizer(
         int d,
         ScalarQuantizer::QuantizerType qtype,
         MetricType metric)
-        : Index(d, metric), sq(d, qtype) {
+        : IndexFlatCodes(0, d, metric), sq(d, qtype) {
     is_trained = qtype == ScalarQuantizer::QT_fp16 ||
             qtype == ScalarQuantizer::QT_8bit_direct;
     code_size = sq.code_size;
@@ -43,21 +44,16 @@ void IndexScalarQuantizer::train(idx_t n, const float* x) {
     is_trained = true;
 }
 
-void IndexScalarQuantizer::add(idx_t n, const float* x) {
-    FAISS_THROW_IF_NOT(is_trained);
-    codes.resize((n + ntotal) * code_size);
-    sq.compute_codes(x, &codes[ntotal * code_size], n);
-    ntotal += n;
-}
-
 void IndexScalarQuantizer::search(
         idx_t n,
         const float* x,
         idx_t k,
         float* distances,
-        idx_t* labels) const {
-    FAISS_THROW_IF_NOT(k > 0);
+        idx_t* labels,
+        const SearchParameters* params) const {
+    const IDSelector* sel = params ? params->sel : nullptr;
 
+    FAISS_THROW_IF_NOT(k > 0);
     FAISS_THROW_IF_NOT(is_trained);
     FAISS_THROW_IF_NOT(
             metric_type == METRIC_L2 || metric_type == METRIC_INNER_PRODUCT);
@@ -65,8 +61,10 @@ void IndexScalarQuantizer::search(
 #pragma omp parallel
     {
         InvertedListScanner* scanner =
-                sq.select_InvertedListScanner(metric_type, nullptr, true);
+                sq.select_InvertedListScanner(metric_type, nullptr, true, sel);
+
         ScopeDeleter1<InvertedListScanner> del(scanner);
+        scanner->list_no = 0; // directly the list number
 
 #pragma omp for
         for (idx_t i = 0; i < n; i++) {
@@ -91,7 +89,8 @@ void IndexScalarQuantizer::search(
     }
 }
 
-DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
+FlatCodesDistanceComputer* IndexScalarQuantizer::get_FlatCodesDistanceComputer()
+        const {
     ScalarQuantizer::SQDistanceComputer* dc =
             sq.get_distance_computer(metric_type);
     dc->code_size = sq.code_size;
@@ -99,27 +98,7 @@ DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
     return dc;
 }
 
-void IndexScalarQuantizer::reset() {
-    codes.clear();
-    ntotal = 0;
-}
-
-void IndexScalarQuantizer::reconstruct_n(idx_t i0, idx_t ni, float* recons)
-        const {
-    std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
-    for (size_t i = 0; i < ni; i++) {
-        squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d);
-    }
-}
-
-void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const {
-    reconstruct_n(key, 1, recons);
-}
-
 /* Codec interface */
-size_t IndexScalarQuantizer::sa_code_size() const {
-    return sq.code_size;
-}
 
 void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
         const {
@@ -166,7 +145,7 @@ void IndexIVFScalarQuantizer::encode_vectors(
         const idx_t* list_nos,
         uint8_t* codes,
         bool include_listnos) const {
-    std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
+    std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
     size_t coarse_size = include_listnos ? coarse_code_size() : 0;
     memset(codes, 0, (code_size + coarse_size) * n);
 
@@ -195,7 +174,7 @@ void IndexIVFScalarQuantizer::encode_vectors(
 
 void IndexIVFScalarQuantizer::sa_decode(idx_t n, const uint8_t* codes, float* x)
         const {
-    std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
+    std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
     size_t coarse_size = coarse_code_size();
 
 #pragma omp parallel if (n > 1000)
@@ -226,7 +205,7 @@ void IndexIVFScalarQuantizer::add_core(
     FAISS_THROW_IF_NOT(is_trained);
 
     size_t nadd = 0;
-    std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
+    std::unique_ptr<ScalarQuantizer::SQuantizer> squant(sq.select_quantizer());
 
     DirectMapAdd dm_add(direct_map, n, xids);
 
@@ -267,22 +246,28 @@ void IndexIVFScalarQuantizer::add_core(
 }
 
 InvertedListScanner* IndexIVFScalarQuantizer::get_InvertedListScanner(
-        bool store_pairs) const {
+        bool store_pairs,
+        const IDSelector* sel) const {
     return sq.select_InvertedListScanner(
-            metric_type, quantizer, store_pairs, by_residual);
+            metric_type, quantizer, store_pairs, sel, by_residual);
 }
 
 void IndexIVFScalarQuantizer::reconstruct_from_offset(
         int64_t list_no,
         int64_t offset,
         float* recons) const {
-    std::vector<float> centroid(d);
-    quantizer->reconstruct(list_no, centroid.data());
-
     const uint8_t* code = invlists->get_single_code(list_no, offset);
-    sq.decode(code, recons, 1);
-    for (int i = 0; i < d; ++i) {
-        recons[i] += centroid[i];
+
+    if (by_residual) {
+        std::vector<float> centroid(d);
+        quantizer->reconstruct(list_no, centroid.data());
+
+        sq.decode(code, recons, 1);
+        for (int i = 0; i < d; ++i) {
+            recons[i] += centroid[i];
+        }
+    } else {
+        sq.decode(code, recons, 1);
     }
 }
 
```
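The recurring change in this file is the new `const SearchParameters* params` argument and the `IDSelector* sel` plumbing through the scanners. A sketch of what this enables, assuming the vendored `faiss::SearchParameters` and `faiss::IDSelectorRange` added elsewhere in this release (see `impl/IDSelector.h` in the file list); the data and dimensions are illustrative:

```cpp
// Sketch: restrict a search to a subset of ids via SearchParameters.
#include <faiss/IndexScalarQuantizer.h>
#include <faiss/impl/IDSelector.h>

#include <vector>

void filtered_search_example() {
    const int d = 16;
    const faiss::Index::idx_t n = 1000, k = 5;
    std::vector<float> data(n * d, 0.0f); // placeholder vectors

    // QT_fp16 needs no training, so we can add right away.
    faiss::IndexScalarQuantizer index(d, faiss::ScalarQuantizer::QT_fp16);
    index.add(n, data.data());

    // Only ids in [100, 200) may be returned.
    faiss::IDSelectorRange sel(100, 200);
    faiss::SearchParameters params;
    params.sel = &sel;

    std::vector<float> distances(k);
    std::vector<faiss::Index::idx_t> labels(k);
    index.search(1, data.data(), k, distances.data(), labels.data(), &params);
}
```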