faiss 0.2.3 → 0.2.4
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
- data/vendor/faiss/faiss/Index2Layer.h +2 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
- data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
- data/vendor/faiss/faiss/IndexFlat.h +9 -15
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
- data/vendor/faiss/faiss/IndexIVF.h +25 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
- data/vendor/faiss/faiss/IndexLSH.h +2 -15
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
- data/vendor/faiss/faiss/IndexPQ.h +2 -17
- data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
- data/vendor/faiss/faiss/IndexRefine.h +10 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
- data/vendor/faiss/faiss/VectorTransform.h +3 -0
- data/vendor/faiss/faiss/clone_index.cpp +3 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
- data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
- data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/index_factory.cpp +585 -414
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/utils/distances.cpp +4 -2
- data/vendor/faiss/faiss/utils/distances.h +36 -3
- data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +12 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
data/vendor/faiss/faiss/VectorTransform.cpp

@@ -357,6 +357,7 @@ PCAMatrix::PCAMatrix(
     is_trained = false;
     max_points_per_d = 1000;
     balanced_bins = 0;
+    epsilon = 0;
 }
 
 namespace {
@@ -620,7 +621,7 @@ void PCAMatrix::prepare_Ab() {
     if (eigen_power != 0) {
         float* ai = A.data();
         for (int i = 0; i < d_out; i++) {
-            float factor = pow(eigenvalues[i], eigen_power);
+            float factor = pow(eigenvalues[i] + epsilon, eigen_power);
            for (int j = 0; j < d_in; j++)
                *ai++ *= factor;
         }
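The new `epsilon` field regularizes the PCA whitening factor: with a negative `eigen_power` (e.g. -0.5 for whitening), a zero or near-zero eigenvalue would otherwise make the factor blow up. A minimal standalone sketch of the effect; the epsilon value below is made up for illustration (the default added above is 0):

// Illustrative only: how the whitening factor behaves with and without epsilon.
#include <cmath>
#include <cstdio>

int main() {
    float eigenvalue = 0.0f;   // a degenerate principal direction
    float eigen_power = -0.5f; // classic whitening exponent
    float epsilon = 1e-6f;     // hypothetical user-chosen regularizer

    float factor_raw = std::pow(eigenvalue, eigen_power);           // inf
    float factor_reg = std::pow(eigenvalue + epsilon, eigen_power); // 1000
    std::printf("raw: %f  regularized: %f\n", factor_raw, factor_reg);
    return 0;
}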
data/vendor/faiss/faiss/clone_index.cpp

@@ -15,6 +15,7 @@
 #include <faiss/impl/FaissAssert.h>
 
 #include <faiss/Index2Layer.h>
+#include <faiss/IndexAdditiveQuantizer.h>
 #include <faiss/IndexFlat.h>
 #include <faiss/IndexHNSW.h>
 #include <faiss/IndexIVF.h>
@@ -27,7 +28,6 @@
 #include <faiss/IndexNSG.h>
 #include <faiss/IndexPQ.h>
 #include <faiss/IndexPreTransform.h>
-#include <faiss/IndexResidual.h>
 #include <faiss/IndexScalarQuantizer.h>
 #include <faiss/MetaIndexes.h>
 #include <faiss/VectorTransform.h>
@@ -80,9 +80,10 @@ Index* Cloner::clone_Index(const Index* index) {
     TRYCLONE(IndexFlatIP, index)
     TRYCLONE(IndexFlat, index)
     TRYCLONE(IndexLattice, index)
-    TRYCLONE(
+    TRYCLONE(IndexResidualQuantizer, index)
     TRYCLONE(IndexScalarQuantizer, index)
     TRYCLONE(MultiIndexQuantizer, index)
+    TRYCLONE(ResidualCoarseQuantizer, index)
     if (const IndexIVF* ivf = dynamic_cast<const IndexIVF*>(index)) {
         IndexIVF* res = clone_IndexIVF(ivf);
         if (ivf->invlists == nullptr) {
data/vendor/faiss/faiss/gpu/GpuCloner.cpp

@@ -40,7 +40,7 @@ void ToCPUCloner::merge_index(Index* dst, Index* src, bool successive_ids) {
         auto ifl2 = dynamic_cast<const IndexFlat*>(src);
         FAISS_ASSERT(ifl2);
         FAISS_ASSERT(successive_ids);
-        ifl->add(ifl2->ntotal, ifl2->
+        ifl->add(ifl2->ntotal, ifl2->get_xb());
     } else if (auto ifl = dynamic_cast<IndexIVFFlat*>(dst)) {
         auto ifl2 = dynamic_cast<IndexIVFFlat*>(src);
         FAISS_ASSERT(ifl2);
@@ -329,7 +329,7 @@ Index* ToGpuClonerMultiple::clone_Index_to_shards(const Index* index) {
             if (index->ntotal > 0) {
                 long i0 = index->ntotal * i / n;
                 long i1 = index->ntotal * (i + 1) / n;
-                shards[i]->add(i1 - i0, index_flat->
+                shards[i]->add(i1 - i0, index_flat->get_xb() + i0 * index->d);
             }
         }
     }
data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h (new file)

@@ -0,0 +1,60 @@
+/**
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ *
+ * This source code is licensed under the MIT license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <faiss/impl/LocalSearchQuantizer.h>
+
+#include <memory>
+
+namespace faiss {
+namespace gpu {
+
+class GpuResourcesProvider;
+struct IcmEncoderShards;
+
+/** Perform LSQ encoding on GPU.
+ *
+ * Split input vectors to different devices and call IcmEncoderImpl::encode
+ * to encode them
+ */
+class GpuIcmEncoder : public lsq::IcmEncoder {
+   public:
+    GpuIcmEncoder(
+            const LocalSearchQuantizer* lsq,
+            const std::vector<GpuResourcesProvider*>& provs,
+            const std::vector<int>& devices);
+
+    ~GpuIcmEncoder();
+
+    GpuIcmEncoder(const GpuIcmEncoder&) = delete;
+    GpuIcmEncoder& operator=(const GpuIcmEncoder&) = delete;
+
+    void set_binary_term() override;
+
+    void encode(
+            int32_t* codes,
+            const float* x,
+            std::mt19937& gen,
+            size_t n,
+            size_t ils_iters) const override;
+
+   private:
+    std::unique_ptr<IcmEncoderShards> shards;
+};
+
+struct GpuIcmEncoderFactory : public lsq::IcmEncoderFactory {
+    explicit GpuIcmEncoderFactory(int ngpus = 1);
+
+    lsq::IcmEncoder* get(const LocalSearchQuantizer* lsq) override;
+
+    std::vector<GpuResourcesProvider*> provs;
+    std::vector<int> devices;
+};
+
+} // namespace gpu
+} // namespace faiss
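The new header exposes a GPU implementation of the ICM encoding step used by `LocalSearchQuantizer`. A hedged usage sketch follows; it assumes the quantizer exposes an `icm_encoder_factory` pointer, which the `LocalSearchQuantizer.h` changes in this release suggest but which is not shown in this excerpt:

// Sketch only, not verified against this exact release.
#include <faiss/gpu/GpuIcmEncoder.h>
#include <faiss/impl/LocalSearchQuantizer.h>

void enable_gpu_icm(faiss::LocalSearchQuantizer& lsq, int ngpus) {
    // assumed member: the factory hands out encoders that shard ICM passes
    // across the requested number of GPUs
    lsq.icm_encoder_factory = new faiss::gpu::GpuIcmEncoderFactory(ngpus);
    // subsequent lsq.train(...) / lsq.compute_codes(...) calls would then
    // run the iterated conditional modes encoding on the GPU
}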
data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp

@@ -8,7 +8,6 @@
 // -*- c++ -*-
 
 #include <faiss/impl/AdditiveQuantizer.h>
-#include <faiss/impl/FaissAssert.h>
 
 #include <cstddef>
 #include <cstdio>
@@ -18,9 +17,10 @@
 
 #include <algorithm>
 
+#include <faiss/impl/FaissAssert.h>
 #include <faiss/utils/Heap.h>
 #include <faiss/utils/distances.h>
-#include <faiss/utils/hamming.h>
+#include <faiss/utils/hamming.h>
 #include <faiss/utils/utils.h>
 
 extern "C" {
@@ -42,51 +42,125 @@ int sgemm_(
         FINTEGER* ldc);
 }
 
-namespace {
-
-// c and a and b can overlap
-void fvec_add(size_t d, const float* a, const float* b, float* c) {
-    for (size_t i = 0; i < d; i++) {
-        c[i] = a[i] + b[i];
-    }
-}
+namespace faiss {
 
-
-
-
-
+AdditiveQuantizer::AdditiveQuantizer(
+        size_t d,
+        const std::vector<size_t>& nbits,
+        Search_type_t search_type)
+        : d(d),
+          M(nbits.size()),
+          nbits(nbits),
+          verbose(false),
+          is_trained(false),
+          search_type(search_type) {
+    norm_max = norm_min = NAN;
+    code_size = 0;
+    tot_bits = 0;
+    total_codebook_size = 0;
+    only_8bit = false;
+    set_derived_values();
 }
 
-
-
-namespace faiss {
+AdditiveQuantizer::AdditiveQuantizer()
+        : AdditiveQuantizer(0, std::vector<size_t>()) {}
 
 void AdditiveQuantizer::set_derived_values() {
     tot_bits = 0;
-
+    only_8bit = true;
     codebook_offsets.resize(M + 1, 0);
     for (int i = 0; i < M; i++) {
         int nbit = nbits[i];
         size_t k = 1 << nbit;
         codebook_offsets[i + 1] = codebook_offsets[i] + k;
         tot_bits += nbit;
-        if (nbit
-
+        if (nbit != 0) {
+            only_8bit = false;
         }
     }
     total_codebook_size = codebook_offsets[M];
+    switch (search_type) {
+        case ST_decompress:
+        case ST_LUT_nonorm:
+        case ST_norm_from_LUT:
+            break; // nothing to add
+        case ST_norm_float:
+            tot_bits += 32;
+            break;
+        case ST_norm_qint8:
+        case ST_norm_cqint8:
+            tot_bits += 8;
+            break;
+        case ST_norm_qint4:
+        case ST_norm_cqint4:
+            tot_bits += 4;
+            break;
+    }
+
     // convert bits to bytes
     code_size = (tot_bits + 7) / 8;
 }
 
+namespace {
+
+// TODO
+// https://stackoverflow.com/questions/31631224/hacks-for-clamping-integer-to-0-255-and-doubles-to-0-0-1-0
+
+uint8_t encode_qint8(float x, float amin, float amax) {
+    float x1 = (x - amin) / (amax - amin) * 256;
+    int32_t xi = int32_t(floor(x1));
+
+    return xi < 0 ? 0 : xi > 255 ? 255 : xi;
+}
+
+uint8_t encode_qint4(float x, float amin, float amax) {
+    float x1 = (x - amin) / (amax - amin) * 16;
+    int32_t xi = int32_t(floor(x1));
+
+    return xi < 0 ? 0 : xi > 15 ? 15 : xi;
+}
+
+float decode_qint8(uint8_t i, float amin, float amax) {
+    return (i + 0.5) / 256 * (amax - amin) + amin;
+}
+
+float decode_qint4(uint8_t i, float amin, float amax) {
+    return (i + 0.5) / 16 * (amax - amin) + amin;
+}
+
+} // anonymous namespace
+
+uint32_t AdditiveQuantizer::encode_qcint(float x) const {
+    idx_t id;
+    qnorm.assign(idx_t(1), &x, &id, idx_t(1));
+    return uint32_t(id);
+}
+
+float AdditiveQuantizer::decode_qcint(uint32_t c) const {
+    return qnorm.get_xb()[c];
+}
+
 void AdditiveQuantizer::pack_codes(
         size_t n,
         const int32_t* codes,
         uint8_t* packed_codes,
-        int64_t ld_codes
+        int64_t ld_codes,
+        const float* norms) const {
     if (ld_codes == -1) {
         ld_codes = M;
     }
+    std::vector<float> norm_buf;
+    if (search_type == ST_norm_float || search_type == ST_norm_qint4 ||
+        search_type == ST_norm_qint8 || search_type == ST_norm_cqint8 ||
+        search_type == ST_norm_cqint4) {
+        if (!norms) {
+            norm_buf.resize(n);
+            std::vector<float> x_recons(n * d);
+            decode_unpacked(codes, x_recons.data(), n, ld_codes);
+            fvec_norms_L2sqr(norm_buf.data(), x_recons.data(), d, n);
+            norms = norm_buf.data();
+        }
+    }
 #pragma omp parallel for if (n > 1000)
     for (int64_t i = 0; i < n; i++) {
         const int32_t* codes1 = codes + i * ld_codes;
|
|
94
168
|
for (int m = 0; m < M; m++) {
|
95
169
|
bsw.write(codes1[m], nbits[m]);
|
96
170
|
}
|
171
|
+
switch (search_type) {
|
172
|
+
case ST_decompress:
|
173
|
+
case ST_LUT_nonorm:
|
174
|
+
case ST_norm_from_LUT:
|
175
|
+
break;
|
176
|
+
case ST_norm_float:
|
177
|
+
bsw.write(*(uint32_t*)&norms[i], 32);
|
178
|
+
break;
|
179
|
+
case ST_norm_qint8: {
|
180
|
+
uint8_t b = encode_qint8(norms[i], norm_min, norm_max);
|
181
|
+
bsw.write(b, 8);
|
182
|
+
break;
|
183
|
+
}
|
184
|
+
case ST_norm_qint4: {
|
185
|
+
uint8_t b = encode_qint4(norms[i], norm_min, norm_max);
|
186
|
+
bsw.write(b, 4);
|
187
|
+
break;
|
188
|
+
}
|
189
|
+
case ST_norm_cqint8: {
|
190
|
+
uint32_t b = encode_qcint(norms[i]);
|
191
|
+
bsw.write(b, 8);
|
192
|
+
break;
|
193
|
+
}
|
194
|
+
case ST_norm_cqint4: {
|
195
|
+
uint32_t b = encode_qcint(norms[i]);
|
196
|
+
bsw.write(b, 4);
|
197
|
+
break;
|
198
|
+
}
|
199
|
+
}
|
97
200
|
}
|
98
201
|
}
|
99
202
|
|
@@ -118,10 +221,39 @@ void AdditiveQuantizer::decode(const uint8_t* code, float* x, size_t n) const {
|
|
118
221
|
}
|
119
222
|
}
|
120
223
|
|
224
|
+
void AdditiveQuantizer::decode_unpacked(
|
225
|
+
const int32_t* code,
|
226
|
+
float* x,
|
227
|
+
size_t n,
|
228
|
+
int64_t ld_codes) const {
|
229
|
+
FAISS_THROW_IF_NOT_MSG(
|
230
|
+
is_trained, "The additive quantizer is not trained yet.");
|
231
|
+
|
232
|
+
if (ld_codes == -1) {
|
233
|
+
ld_codes = M;
|
234
|
+
}
|
235
|
+
|
236
|
+
// standard additive quantizer decoding
|
237
|
+
#pragma omp parallel for if (n > 1000)
|
238
|
+
for (int64_t i = 0; i < n; i++) {
|
239
|
+
const int32_t* codesi = code + i * ld_codes;
|
240
|
+
float* xi = x + i * d;
|
241
|
+
for (int m = 0; m < M; m++) {
|
242
|
+
int idx = codesi[m];
|
243
|
+
const float* c = codebooks.data() + d * (codebook_offsets[m] + idx);
|
244
|
+
if (m == 0) {
|
245
|
+
memcpy(xi, c, sizeof(*x) * d);
|
246
|
+
} else {
|
247
|
+
fvec_add(d, xi, c, xi);
|
248
|
+
}
|
249
|
+
}
|
250
|
+
}
|
251
|
+
}
|
252
|
+
|
121
253
|
AdditiveQuantizer::~AdditiveQuantizer() {}
|
122
254
|
|
123
255
|
/****************************************************************************
|
124
|
-
* Support for fast distance computations
|
256
|
+
* Support for fast distance computations in centroids
|
125
257
|
****************************************************************************/
|
126
258
|
|
127
259
|
void AdditiveQuantizer::compute_centroid_norms(float* norms) const {
|
@@ -201,7 +333,7 @@ void compute_inner_prod_with_LUT(
|
|
201
333
|
|
202
334
|
} // anonymous namespace
|
203
335
|
|
204
|
-
void AdditiveQuantizer::
|
336
|
+
void AdditiveQuantizer::knn_centroids_inner_product(
|
205
337
|
idx_t n,
|
206
338
|
const float* xq,
|
207
339
|
idx_t k,
|
@@ -227,7 +359,7 @@ void AdditiveQuantizer::knn_exact_inner_product(
|
|
227
359
|
}
|
228
360
|
}
|
229
361
|
|
230
|
-
void AdditiveQuantizer::
|
362
|
+
void AdditiveQuantizer::knn_centroids_L2(
|
231
363
|
idx_t n,
|
232
364
|
const float* xq,
|
233
365
|
idx_t k,
|
@@ -267,4 +399,105 @@ void AdditiveQuantizer::knn_exact_L2(
|
|
267
399
|
}
|
268
400
|
}
|
269
401
|
|
402
|
+
/****************************************************************************
|
403
|
+
* Support for fast distance computations in codes
|
404
|
+
****************************************************************************/
|
405
|
+
|
406
|
+
namespace {
|
407
|
+
|
408
|
+
float accumulate_IPs(
|
409
|
+
const AdditiveQuantizer& aq,
|
410
|
+
BitstringReader& bs,
|
411
|
+
const uint8_t* codes,
|
412
|
+
const float* LUT) {
|
413
|
+
float accu = 0;
|
414
|
+
for (int m = 0; m < aq.M; m++) {
|
415
|
+
size_t nbit = aq.nbits[m];
|
416
|
+
int idx = bs.read(nbit);
|
417
|
+
accu += LUT[idx];
|
418
|
+
LUT += (uint64_t)1 << nbit;
|
419
|
+
}
|
420
|
+
return accu;
|
421
|
+
}
|
422
|
+
|
423
|
+
} // anonymous namespace
|
424
|
+
|
425
|
+
template <>
|
426
|
+
float AdditiveQuantizer::
|
427
|
+
compute_1_distance_LUT<true, AdditiveQuantizer::ST_LUT_nonorm>(
|
428
|
+
const uint8_t* codes,
|
429
|
+
const float* LUT) const {
|
430
|
+
BitstringReader bs(codes, code_size);
|
431
|
+
return accumulate_IPs(*this, bs, codes, LUT);
|
432
|
+
}
|
433
|
+
|
434
|
+
template <>
|
435
|
+
float AdditiveQuantizer::
|
436
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_LUT_nonorm>(
|
437
|
+
const uint8_t* codes,
|
438
|
+
const float* LUT) const {
|
439
|
+
BitstringReader bs(codes, code_size);
|
440
|
+
return -accumulate_IPs(*this, bs, codes, LUT);
|
441
|
+
}
|
442
|
+
|
443
|
+
template <>
|
444
|
+
float AdditiveQuantizer::
|
445
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_float>(
|
446
|
+
const uint8_t* codes,
|
447
|
+
const float* LUT) const {
|
448
|
+
BitstringReader bs(codes, code_size);
|
449
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
450
|
+
uint32_t norm_i = bs.read(32);
|
451
|
+
float norm2 = *(float*)&norm_i;
|
452
|
+
return norm2 - 2 * accu;
|
453
|
+
}
|
454
|
+
|
455
|
+
template <>
|
456
|
+
float AdditiveQuantizer::
|
457
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint8>(
|
458
|
+
const uint8_t* codes,
|
459
|
+
const float* LUT) const {
|
460
|
+
BitstringReader bs(codes, code_size);
|
461
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
462
|
+
uint32_t norm_i = bs.read(8);
|
463
|
+
float norm2 = decode_qcint(norm_i);
|
464
|
+
return norm2 - 2 * accu;
|
465
|
+
}
|
466
|
+
|
467
|
+
template <>
|
468
|
+
float AdditiveQuantizer::
|
469
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_cqint4>(
|
470
|
+
const uint8_t* codes,
|
471
|
+
const float* LUT) const {
|
472
|
+
BitstringReader bs(codes, code_size);
|
473
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
474
|
+
uint32_t norm_i = bs.read(4);
|
475
|
+
float norm2 = decode_qcint(norm_i);
|
476
|
+
return norm2 - 2 * accu;
|
477
|
+
}
|
478
|
+
|
479
|
+
template <>
|
480
|
+
float AdditiveQuantizer::
|
481
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint8>(
|
482
|
+
const uint8_t* codes,
|
483
|
+
const float* LUT) const {
|
484
|
+
BitstringReader bs(codes, code_size);
|
485
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
486
|
+
uint32_t norm_i = bs.read(8);
|
487
|
+
float norm2 = decode_qint8(norm_i, norm_min, norm_max);
|
488
|
+
return norm2 - 2 * accu;
|
489
|
+
}
|
490
|
+
|
491
|
+
template <>
|
492
|
+
float AdditiveQuantizer::
|
493
|
+
compute_1_distance_LUT<false, AdditiveQuantizer::ST_norm_qint4>(
|
494
|
+
const uint8_t* codes,
|
495
|
+
const float* LUT) const {
|
496
|
+
BitstringReader bs(codes, code_size);
|
497
|
+
float accu = accumulate_IPs(*this, bs, codes, LUT);
|
498
|
+
uint32_t norm_i = bs.read(4);
|
499
|
+
float norm2 = decode_qint4(norm_i, norm_min, norm_max);
|
500
|
+
return norm2 - 2 * accu;
|
501
|
+
}
|
502
|
+
|
270
503
|
} // namespace faiss
|
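For context on why the `ST_norm_*` specializations above return `norm2 - 2 * accu`: an additive quantizer reconstructs a vector as the sum of one codebook entry per sub-quantizer, so summing the per-entry look-up values gives the full query-vector inner product, and the query's own norm can be dropped because it is constant across the database. A short sketch of the algebra (notation mine, not from the source):

y = \sum_{m=1}^{M} c_{m,\,i_m},
\qquad
\texttt{accu} = \sum_{m=1}^{M} \mathrm{LUT}_m[i_m] = \langle q, y \rangle

\|q - y\|^2 = \|q\|^2 + \|y\|^2 - 2\langle q, y\rangle
\;\Longrightarrow\;
\|q - y\|^2 - \|q\|^2 = \texttt{norm2} - 2\,\texttt{accu}

Here `norm2` is \(\|y\|^2\), stored as a float for ST_norm_float or re-quantized for the qint/cqint variants, so the returned value ranks candidates identically to the true squared L2 distance.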
data/vendor/faiss/faiss/impl/AdditiveQuantizer.h

@@ -11,6 +11,7 @@
 #include <vector>
 
 #include <faiss/Index.h>
+#include <faiss/IndexFlat.h>
 
 namespace faiss {
 
@@ -27,15 +28,44 @@ struct AdditiveQuantizer {
     std::vector<float> codebooks; ///< codebooks
 
     // derived values
-    std::vector<
+    std::vector<uint64_t> codebook_offsets;
     size_t code_size;           ///< code size in bytes
     size_t tot_bits;            ///< total number of bits
     size_t total_codebook_size; ///< size of the codebook in vectors
-    bool
+    bool only_8bit;             ///< are all nbits = 8 (use faster decoder)
 
     bool verbose;    ///< verbose during training?
     bool is_trained; ///< is trained or not
 
+    IndexFlat1D qnorm; ///< store and search norms
+
+    uint32_t encode_qcint(
+            float x) const; ///< encode norm by non-uniform scalar quantization
+
+    float decode_qcint(uint32_t c)
+            const; ///< decode norm by non-uniform scalar quantization
+
+    /// Encodes how search is performed and how vectors are encoded
+    enum Search_type_t {
+        ST_decompress,    ///< decompress database vector
+        ST_LUT_nonorm,    ///< use a LUT, don't include norms (OK for IP or
+                          ///< normalized vectors)
+        ST_norm_from_LUT, ///< compute the norms from the look-up tables (cost
+                          ///< is in O(M^2))
+        ST_norm_float,    ///< use a LUT, and store float32 norm with the vectors
+        ST_norm_qint8,    ///< use a LUT, and store 8bit-quantized norm
+        ST_norm_qint4,
+        ST_norm_cqint8,   ///< use a LUT, and store non-uniform quantized norm
+        ST_norm_cqint4,
+    };
+
+    AdditiveQuantizer(
+            size_t d,
+            const std::vector<size_t>& nbits,
+            Search_type_t search_type = ST_decompress);
+
+    AdditiveQuantizer();
+
     ///< compute derived values when d, M and nbits have been set
     void set_derived_values();
 
@@ -52,15 +82,18 @@ struct AdditiveQuantizer {
 
     /** pack a series of code to bit-compact format
      *
-     * @param codes
+     * @param codes         codes to be packed, size n * code_size
     * @param packed_codes  output bit-compact codes
-     * @param ld_codes
+     * @param ld_codes      leading dimension of codes
+     * @param norms         norms of the vectors (size n). Will be computed if
+     *                      needed but not provided
      */
     void pack_codes(
             size_t n,
             const int32_t* codes,
             uint8_t* packed_codes,
-            int64_t ld_codes = -1
+            int64_t ld_codes = -1,
+            const float* norms = nullptr) const;
 
     /** Decode a set of vectors
      *
@@ -69,9 +102,36 @@ struct AdditiveQuantizer {
      */
     void decode(const uint8_t* codes, float* x, size_t n) const;
 
+    /** Decode a set of vectors in non-packed format
+     *
+     * @param codes  codes to decode, size n * ld_codes
+     * @param x      output vectors, size n * d
+     */
+    void decode_unpacked(
+            const int32_t* codes,
+            float* x,
+            size_t n,
+            int64_t ld_codes = -1) const;
+
+    /****************************************************************************
+     * Search functions in an external set of codes.
+     ****************************************************************************/
+
+    /// Also determines what's in the codes
+    Search_type_t search_type;
+
+    /// min/max for quantization of norms
+    float norm_min, norm_max;
+
+    template <bool is_IP, Search_type_t effective_search_type>
+    float compute_1_distance_LUT(const uint8_t* codes, const float* LUT) const;
+
+    /*
+    float compute_1_L2sqr(const uint8_t* codes, const float* LUT);
+    */
     /****************************************************************************
-     * Support for exhaustive distance computations with the centroids.
-     * Hence, the number of
+     * Support for exhaustive distance computations with all the centroids.
+     * Hence, the number of these centroids should not be too large.
      ****************************************************************************/
     using idx_t = Index::idx_t;
 
@@ -87,7 +147,7 @@ struct AdditiveQuantizer {
     void compute_LUT(size_t n, const float* xq, float* LUT) const;
 
     /// exact IP search
-    void
+    void knn_centroids_inner_product(
            idx_t n,
            const float* xq,
            idx_t k,
@@ -101,7 +161,7 @@ struct AdditiveQuantizer {
     void compute_centroid_norms(float* norms) const;
 
     /** Exact L2 search, with precomputed norms */
-    void
+    void knn_centroids_L2(
            idx_t n,
            const float* xq,
            idx_t k,
data/vendor/faiss/faiss/impl/HNSW.cpp

@@ -434,17 +434,22 @@ void HNSW::add_links_starting_from(
 
     ::faiss::shrink_neighbor_list(ptdis, link_targets, M);
 
+    std::vector<storage_idx_t> neighbors;
+    neighbors.reserve(link_targets.size());
     while (!link_targets.empty()) {
-
+        storage_idx_t other_id = link_targets.top().id;
+        add_link(*this, ptdis, pt_id, other_id, level);
+        neighbors.push_back(other_id);
+        link_targets.pop();
+    }
 
+    omp_unset_lock(&locks[pt_id]);
+    for (storage_idx_t other_id : neighbors) {
         omp_set_lock(&locks[other_id]);
         add_link(*this, ptdis, other_id, pt_id, level);
         omp_unset_lock(&locks[other_id]);
-
-        add_link(*this, ptdis, pt_id, other_id, level);
-
-        link_targets.pop();
     }
+    omp_set_lock(&locks[pt_id]);
 }
 
 /**************************************************************
|