faiss 0.2.3 → 0.2.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/LICENSE.txt +1 -1
- data/lib/faiss/version.rb +1 -1
- data/vendor/faiss/faiss/Clustering.cpp +32 -0
- data/vendor/faiss/faiss/Clustering.h +14 -0
- data/vendor/faiss/faiss/Index.h +1 -1
- data/vendor/faiss/faiss/Index2Layer.cpp +19 -92
- data/vendor/faiss/faiss/Index2Layer.h +2 -16
- data/vendor/faiss/faiss/IndexAdditiveQuantizer.cpp +407 -0
- data/vendor/faiss/faiss/{IndexResidual.h → IndexAdditiveQuantizer.h} +101 -58
- data/vendor/faiss/faiss/IndexFlat.cpp +22 -52
- data/vendor/faiss/faiss/IndexFlat.h +9 -15
- data/vendor/faiss/faiss/IndexFlatCodes.cpp +67 -0
- data/vendor/faiss/faiss/IndexFlatCodes.h +47 -0
- data/vendor/faiss/faiss/IndexIVF.cpp +79 -7
- data/vendor/faiss/faiss/IndexIVF.h +25 -7
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.cpp +316 -0
- data/vendor/faiss/faiss/IndexIVFAdditiveQuantizer.h +121 -0
- data/vendor/faiss/faiss/IndexIVFFlat.cpp +9 -12
- data/vendor/faiss/faiss/IndexIVFPQ.cpp +5 -4
- data/vendor/faiss/faiss/IndexIVFPQ.h +1 -1
- data/vendor/faiss/faiss/IndexIVFSpectralHash.cpp +60 -39
- data/vendor/faiss/faiss/IndexIVFSpectralHash.h +21 -6
- data/vendor/faiss/faiss/IndexLSH.cpp +4 -30
- data/vendor/faiss/faiss/IndexLSH.h +2 -15
- data/vendor/faiss/faiss/IndexNNDescent.cpp +0 -2
- data/vendor/faiss/faiss/IndexNSG.cpp +0 -2
- data/vendor/faiss/faiss/IndexPQ.cpp +2 -51
- data/vendor/faiss/faiss/IndexPQ.h +2 -17
- data/vendor/faiss/faiss/IndexRefine.cpp +28 -0
- data/vendor/faiss/faiss/IndexRefine.h +10 -0
- data/vendor/faiss/faiss/IndexScalarQuantizer.cpp +2 -28
- data/vendor/faiss/faiss/IndexScalarQuantizer.h +2 -16
- data/vendor/faiss/faiss/VectorTransform.cpp +2 -1
- data/vendor/faiss/faiss/VectorTransform.h +3 -0
- data/vendor/faiss/faiss/clone_index.cpp +3 -2
- data/vendor/faiss/faiss/gpu/GpuCloner.cpp +2 -2
- data/vendor/faiss/faiss/gpu/GpuIcmEncoder.h +60 -0
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.cpp +257 -24
- data/vendor/faiss/faiss/impl/AdditiveQuantizer.h +69 -9
- data/vendor/faiss/faiss/impl/HNSW.cpp +10 -5
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.cpp +393 -210
- data/vendor/faiss/faiss/impl/LocalSearchQuantizer.h +100 -28
- data/vendor/faiss/faiss/impl/NSG.cpp +0 -3
- data/vendor/faiss/faiss/impl/NSG.h +1 -1
- data/vendor/faiss/faiss/impl/ResidualQuantizer.cpp +357 -47
- data/vendor/faiss/faiss/impl/ResidualQuantizer.h +65 -7
- data/vendor/faiss/faiss/impl/ScalarQuantizer.cpp +12 -19
- data/vendor/faiss/faiss/impl/index_read.cpp +102 -19
- data/vendor/faiss/faiss/impl/index_write.cpp +66 -16
- data/vendor/faiss/faiss/impl/io.cpp +1 -1
- data/vendor/faiss/faiss/impl/io_macros.h +20 -0
- data/vendor/faiss/faiss/impl/kmeans1d.cpp +301 -0
- data/vendor/faiss/faiss/impl/kmeans1d.h +48 -0
- data/vendor/faiss/faiss/index_factory.cpp +585 -414
- data/vendor/faiss/faiss/index_factory.h +3 -0
- data/vendor/faiss/faiss/utils/distances.cpp +4 -2
- data/vendor/faiss/faiss/utils/distances.h +36 -3
- data/vendor/faiss/faiss/utils/distances_simd.cpp +50 -0
- data/vendor/faiss/faiss/utils/utils.h +1 -1
- metadata +12 -5
- data/vendor/faiss/faiss/IndexResidual.cpp +0 -291
@@ -13,6 +13,8 @@
|
|
13
13
|
#include <algorithm>
|
14
14
|
#include <memory>
|
15
15
|
|
16
|
+
#include <faiss/IndexLSH.h>
|
17
|
+
#include <faiss/IndexPreTransform.h>
|
16
18
|
#include <faiss/VectorTransform.h>
|
17
19
|
#include <faiss/impl/AuxIndexStructures.h>
|
18
20
|
#include <faiss/impl/FaissAssert.h>
|
@@ -31,7 +33,6 @@ IndexIVFSpectralHash::IndexIVFSpectralHash(
|
|
31
33
|
nbit(nbit),
|
32
34
|
period(period),
|
33
35
|
threshold_type(Thresh_global) {
|
34
|
-
FAISS_THROW_IF_NOT(code_size % 4 == 0);
|
35
36
|
RandomRotationMatrix* rr = new RandomRotationMatrix(d, nbit);
|
36
37
|
rr->init(1234);
|
37
38
|
vt = rr;
|
@@ -151,8 +152,8 @@ void binarize_with_freq(
|
|
151
152
|
memset(codes, 0, (nbit + 7) / 8);
|
152
153
|
for (size_t i = 0; i < nbit; i++) {
|
153
154
|
float xf = (x[i] - c[i]);
|
154
|
-
|
155
|
-
|
155
|
+
int64_t xi = int64_t(floor(xf * freq));
|
156
|
+
int64_t bit = xi & 1;
|
156
157
|
codes[i >> 3] |= bit << (i & 7);
|
157
158
|
}
|
158
159
|
}
|
@@ -167,35 +168,33 @@ void IndexIVFSpectralHash::encode_vectors(
|
|
167
168
|
bool include_listnos) const {
|
168
169
|
FAISS_THROW_IF_NOT(is_trained);
|
169
170
|
float freq = 2.0 / period;
|
170
|
-
|
171
|
-
FAISS_THROW_IF_NOT_MSG(!include_listnos, "listnos encoding not supported");
|
171
|
+
size_t coarse_size = include_listnos ? coarse_code_size() : 0;
|
172
172
|
|
173
173
|
// transform with vt
|
174
174
|
std::unique_ptr<float[]> x(vt->apply(n, x_in));
|
175
175
|
|
176
|
-
|
177
|
-
{
|
178
|
-
std::vector<float> zero(nbit);
|
176
|
+
std::vector<float> zero(nbit);
|
179
177
|
|
180
|
-
// each thread takes care of a subset of lists
|
181
178
|
#pragma omp for
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
c,
|
197
|
-
codes + i * code_size);
|
179
|
+
for (idx_t i = 0; i < n; i++) {
|
180
|
+
int64_t list_no = list_nos[i];
|
181
|
+
uint8_t* code = codes + i * (code_size + coarse_size);
|
182
|
+
|
183
|
+
if (list_no >= 0) {
|
184
|
+
if (coarse_size) {
|
185
|
+
encode_listno(list_no, code);
|
186
|
+
}
|
187
|
+
const float* c;
|
188
|
+
|
189
|
+
if (threshold_type == Thresh_global) {
|
190
|
+
c = zero.data();
|
191
|
+
} else {
|
192
|
+
c = trained.data() + list_no * nbit;
|
198
193
|
}
|
194
|
+
binarize_with_freq(
|
195
|
+
nbit, freq, x.get() + i * nbit, c, code + coarse_size);
|
196
|
+
} else {
|
197
|
+
memset(code, 0, code_size + coarse_size);
|
199
198
|
}
|
200
199
|
}
|
201
200
|
}
|
@@ -206,9 +205,7 @@ template <class HammingComputer>
|
|
206
205
|
struct IVFScanner : InvertedListScanner {
|
207
206
|
// copied from index structure
|
208
207
|
const IndexIVFSpectralHash* index;
|
209
|
-
size_t code_size;
|
210
208
|
size_t nbit;
|
211
|
-
bool store_pairs;
|
212
209
|
|
213
210
|
float period, freq;
|
214
211
|
std::vector<float> q;
|
@@ -220,15 +217,16 @@ struct IVFScanner : InvertedListScanner {
|
|
220
217
|
|
221
218
|
IVFScanner(const IndexIVFSpectralHash* index, bool store_pairs)
|
222
219
|
: index(index),
|
223
|
-
code_size(index->code_size),
|
224
220
|
nbit(index->nbit),
|
225
|
-
store_pairs(store_pairs),
|
226
221
|
period(index->period),
|
227
222
|
freq(2.0 / index->period),
|
228
223
|
q(nbit),
|
229
224
|
zero(nbit),
|
230
|
-
qcode(code_size),
|
231
|
-
hc(qcode.data(), code_size) {
|
225
|
+
qcode(index->code_size),
|
226
|
+
hc(qcode.data(), index->code_size) {
|
227
|
+
this->store_pairs = store_pairs;
|
228
|
+
this->code_size = index->code_size;
|
229
|
+
}
|
232
230
|
|
233
231
|
void set_query(const float* query) override {
|
234
232
|
FAISS_THROW_IF_NOT(query);
|
@@ -241,8 +239,6 @@ struct IVFScanner : InvertedListScanner {
|
|
241
239
|
}
|
242
240
|
}
|
243
241
|
|
244
|
-
idx_t list_no;
|
245
|
-
|
246
242
|
void set_list(idx_t list_no, float /*coarse_dis*/) override {
|
247
243
|
this->list_no = list_no;
|
248
244
|
if (index->threshold_type != IndexIVFSpectralHash::Thresh_global) {
|
@@ -310,13 +306,38 @@ InvertedListScanner* IndexIVFSpectralHash::get_InvertedListScanner(
|
|
310
306
|
HANDLE_CODE_SIZE(64);
|
311
307
|
#undef HANDLE_CODE_SIZE
|
312
308
|
default:
|
313
|
-
|
314
|
-
|
315
|
-
|
316
|
-
|
317
|
-
|
318
|
-
|
309
|
+
return new IVFScanner<HammingComputerDefault>(this, store_pairs);
|
310
|
+
}
|
311
|
+
}
|
312
|
+
|
313
|
+
void IndexIVFSpectralHash::replace_vt(VectorTransform* vt_in, bool own) {
|
314
|
+
FAISS_THROW_IF_NOT(vt_in->d_out == nbit);
|
315
|
+
FAISS_THROW_IF_NOT(vt_in->d_in == d);
|
316
|
+
if (own_fields) {
|
317
|
+
delete vt;
|
319
318
|
}
|
319
|
+
vt = vt_in;
|
320
|
+
threshold_type = Thresh_global;
|
321
|
+
is_trained = quantizer->is_trained && quantizer->ntotal == nlist &&
|
322
|
+
vt->is_trained;
|
323
|
+
own_fields = own;
|
324
|
+
}
|
325
|
+
|
326
|
+
/*
|
327
|
+
Check that the encoder is a single vector transform followed by a LSH
|
328
|
+
that just does thresholding.
|
329
|
+
If this is not the case, the linear transform + threhsolds of the IndexLSH
|
330
|
+
should be merged into the VectorTransform (which is feasible).
|
331
|
+
*/
|
332
|
+
|
333
|
+
void IndexIVFSpectralHash::replace_vt(IndexPreTransform* encoder, bool own) {
|
334
|
+
FAISS_THROW_IF_NOT(encoder->chain.size() == 1);
|
335
|
+
auto sub_index = dynamic_cast<IndexLSH*>(encoder->index);
|
336
|
+
FAISS_THROW_IF_NOT_MSG(sub_index, "final index should be LSH");
|
337
|
+
FAISS_THROW_IF_NOT(sub_index->nbits == nbit);
|
338
|
+
FAISS_THROW_IF_NOT(!sub_index->rotate_data);
|
339
|
+
FAISS_THROW_IF_NOT(!sub_index->train_thresholds);
|
340
|
+
replace_vt(encoder->chain[0], own);
|
320
341
|
}
|
321
342
|
|
322
343
|
} // namespace faiss
|
@@ -17,6 +17,7 @@
|
|
17
17
|
namespace faiss {
|
18
18
|
|
19
19
|
struct VectorTransform;
|
20
|
+
struct IndexPreTransform;
|
20
21
|
|
21
22
|
/** Inverted list that stores binary codes of size nbit. Before the
|
22
23
|
* binary conversion, the dimension of the vectors is transformed from
|
@@ -25,23 +26,29 @@ struct VectorTransform;
|
|
25
26
|
* Each coordinate is subtracted from a value determined by
|
26
27
|
* threshold_type, and split into intervals of size period. Half of
|
27
28
|
* the interval is a 0 bit, the other half a 1.
|
29
|
+
*
|
28
30
|
*/
|
29
31
|
struct IndexIVFSpectralHash : IndexIVF {
|
30
|
-
|
32
|
+
/// transformation from d to nbit dim
|
33
|
+
VectorTransform* vt;
|
34
|
+
/// own the vt
|
31
35
|
bool own_fields;
|
32
36
|
|
37
|
+
/// nb of bits of the binary signature
|
33
38
|
int nbit;
|
39
|
+
/// interval size for 0s and 1s
|
34
40
|
float period;
|
35
41
|
|
36
42
|
enum ThresholdType {
|
37
|
-
Thresh_global,
|
38
|
-
Thresh_centroid,
|
39
|
-
Thresh_centroid_half,
|
40
|
-
Thresh_median
|
43
|
+
Thresh_global, ///< global threshold at 0
|
44
|
+
Thresh_centroid, ///< compare to centroid
|
45
|
+
Thresh_centroid_half, ///< central interval around centroid
|
46
|
+
Thresh_median ///< median of training set
|
41
47
|
};
|
42
48
|
ThresholdType threshold_type;
|
43
49
|
|
44
|
-
|
50
|
+
/// Trained threshold.
|
51
|
+
/// size nlist * nbit or 0 if Thresh_global
|
45
52
|
std::vector<float> trained;
|
46
53
|
|
47
54
|
IndexIVFSpectralHash(
|
@@ -65,6 +72,14 @@ struct IndexIVFSpectralHash : IndexIVF {
|
|
65
72
|
InvertedListScanner* get_InvertedListScanner(
|
66
73
|
bool store_pairs) const override;
|
67
74
|
|
75
|
+
/** replace the vector transform for an empty (and possibly untrained) index
|
76
|
+
*/
|
77
|
+
void replace_vt(VectorTransform* vt, bool own = false);
|
78
|
+
|
79
|
+
/** convenience function to get the VT from an index constucted by an
|
80
|
+
* index_factory (should end in "LSH") */
|
81
|
+
void replace_vt(IndexPreTransform* index, bool own = false);
|
82
|
+
|
68
83
|
~IndexIVFSpectralHash() override;
|
69
84
|
};
|
70
85
|
|
@@ -5,8 +5,6 @@
|
|
5
5
|
* LICENSE file in the root directory of this source tree.
|
6
6
|
*/
|
7
7
|
|
8
|
-
// -*- c++ -*-
|
9
|
-
|
10
8
|
#include <faiss/IndexLSH.h>
|
11
9
|
|
12
10
|
#include <cstdio>
|
@@ -25,15 +23,13 @@ namespace faiss {
|
|
25
23
|
***************************************************************/
|
26
24
|
|
27
25
|
IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
|
28
|
-
:
|
26
|
+
: IndexFlatCodes((nbits + 7) / 8, d),
|
29
27
|
nbits(nbits),
|
30
28
|
rotate_data(rotate_data),
|
31
29
|
train_thresholds(train_thresholds),
|
32
30
|
rrot(d, nbits) {
|
33
31
|
is_trained = !train_thresholds;
|
34
32
|
|
35
|
-
bytes_per_vec = (nbits + 7) / 8;
|
36
|
-
|
37
33
|
if (rotate_data) {
|
38
34
|
rrot.init(5);
|
39
35
|
} else {
|
@@ -41,11 +37,7 @@ IndexLSH::IndexLSH(idx_t d, int nbits, bool rotate_data, bool train_thresholds)
|
|
41
37
|
}
|
42
38
|
}
|
43
39
|
|
44
|
-
IndexLSH::IndexLSH()
|
45
|
-
: nbits(0),
|
46
|
-
bytes_per_vec(0),
|
47
|
-
rotate_data(false),
|
48
|
-
train_thresholds(false) {}
|
40
|
+
IndexLSH::IndexLSH() : nbits(0), rotate_data(false), train_thresholds(false) {}
|
49
41
|
|
50
42
|
const float* IndexLSH::apply_preprocess(idx_t n, const float* x) const {
|
51
43
|
float* xt = nullptr;
|
@@ -106,15 +98,6 @@ void IndexLSH::train(idx_t n, const float* x) {
|
|
106
98
|
is_trained = true;
|
107
99
|
}
|
108
100
|
|
109
|
-
void IndexLSH::add(idx_t n, const float* x) {
|
110
|
-
FAISS_THROW_IF_NOT(is_trained);
|
111
|
-
codes.resize((ntotal + n) * bytes_per_vec);
|
112
|
-
|
113
|
-
sa_encode(n, x, &codes[ntotal * bytes_per_vec]);
|
114
|
-
|
115
|
-
ntotal += n;
|
116
|
-
}
|
117
|
-
|
118
101
|
void IndexLSH::search(
|
119
102
|
idx_t n,
|
120
103
|
const float* x,
|
@@ -127,7 +110,7 @@ void IndexLSH::search(
|
|
127
110
|
const float* xt = apply_preprocess(n, x);
|
128
111
|
ScopeDeleter<float> del(xt == x ? nullptr : xt);
|
129
112
|
|
130
|
-
uint8_t* qcodes = new uint8_t[n *
|
113
|
+
uint8_t* qcodes = new uint8_t[n * code_size];
|
131
114
|
ScopeDeleter<uint8_t> del2(qcodes);
|
132
115
|
|
133
116
|
fvecs2bitvecs(xt, qcodes, nbits, n);
|
@@ -137,7 +120,7 @@ void IndexLSH::search(
|
|
137
120
|
|
138
121
|
int_maxheap_array_t res = {size_t(n), size_t(k), labels, idistances};
|
139
122
|
|
140
|
-
hammings_knn_hc(&res, qcodes, codes.data(), ntotal,
|
123
|
+
hammings_knn_hc(&res, qcodes, codes.data(), ntotal, code_size, true);
|
141
124
|
|
142
125
|
// convert distances to floats
|
143
126
|
for (int i = 0; i < k * n; i++)
|
@@ -158,15 +141,6 @@ void IndexLSH::transfer_thresholds(LinearTransform* vt) {
|
|
158
141
|
thresholds.clear();
|
159
142
|
}
|
160
143
|
|
161
|
-
void IndexLSH::reset() {
|
162
|
-
codes.clear();
|
163
|
-
ntotal = 0;
|
164
|
-
}
|
165
|
-
|
166
|
-
size_t IndexLSH::sa_code_size() const {
|
167
|
-
return bytes_per_vec;
|
168
|
-
}
|
169
|
-
|
170
144
|
void IndexLSH::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
171
145
|
FAISS_THROW_IF_NOT(is_trained);
|
172
146
|
const float* xt = apply_preprocess(n, x);
|
@@ -12,17 +12,14 @@
|
|
12
12
|
|
13
13
|
#include <vector>
|
14
14
|
|
15
|
-
#include <faiss/
|
15
|
+
#include <faiss/IndexFlatCodes.h>
|
16
16
|
#include <faiss/VectorTransform.h>
|
17
17
|
|
18
18
|
namespace faiss {
|
19
19
|
|
20
20
|
/** The sign of each vector component is put in a binary signature */
|
21
|
-
struct IndexLSH :
|
22
|
-
typedef unsigned char uint8_t;
|
23
|
-
|
21
|
+
struct IndexLSH : IndexFlatCodes {
|
24
22
|
int nbits; ///< nb of bits per vector
|
25
|
-
int bytes_per_vec; ///< nb of 8-bits per encoded vector
|
26
23
|
bool rotate_data; ///< whether to apply a random rotation to input
|
27
24
|
bool train_thresholds; ///< whether we train thresholds or use 0
|
28
25
|
|
@@ -30,9 +27,6 @@ struct IndexLSH : Index {
|
|
30
27
|
|
31
28
|
std::vector<float> thresholds; ///< thresholds to compare with
|
32
29
|
|
33
|
-
/// encoded dataset
|
34
|
-
std::vector<uint8_t> codes;
|
35
|
-
|
36
30
|
IndexLSH(
|
37
31
|
idx_t d,
|
38
32
|
int nbits,
|
@@ -50,8 +44,6 @@ struct IndexLSH : Index {
|
|
50
44
|
|
51
45
|
void train(idx_t n, const float* x) override;
|
52
46
|
|
53
|
-
void add(idx_t n, const float* x) override;
|
54
|
-
|
55
47
|
void search(
|
56
48
|
idx_t n,
|
57
49
|
const float* x,
|
@@ -59,8 +51,6 @@ struct IndexLSH : Index {
|
|
59
51
|
float* distances,
|
60
52
|
idx_t* labels) const override;
|
61
53
|
|
62
|
-
void reset() override;
|
63
|
-
|
64
54
|
/// transfer the thresholds to a pre-processing stage (and unset
|
65
55
|
/// train_thresholds)
|
66
56
|
void transfer_thresholds(LinearTransform* vt);
|
@@ -72,9 +62,6 @@ struct IndexLSH : Index {
|
|
72
62
|
/* standalone codec interface.
|
73
63
|
*
|
74
64
|
* The vectors are decoded to +/- 1 (not 0, 1) */
|
75
|
-
|
76
|
-
size_t sa_code_size() const override;
|
77
|
-
|
78
65
|
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
79
66
|
|
80
67
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
@@ -167,9 +167,7 @@ void IndexNNDescent::search(
|
|
167
167
|
float* simi = distances + i * k;
|
168
168
|
dis->set_query(x + i * d);
|
169
169
|
|
170
|
-
maxheap_heapify(k, simi, idxi);
|
171
170
|
nndescent.search(*dis, k, idxi, simi, vt);
|
172
|
-
maxheap_reorder(k, simi, idxi);
|
173
171
|
}
|
174
172
|
}
|
175
173
|
InterruptCallback::check();
|
@@ -28,12 +28,13 @@ namespace faiss {
|
|
28
28
|
********************************************************/
|
29
29
|
|
30
30
|
IndexPQ::IndexPQ(int d, size_t M, size_t nbits, MetricType metric)
|
31
|
-
:
|
31
|
+
: IndexFlatCodes(0, d, metric), pq(d, M, nbits) {
|
32
32
|
is_trained = false;
|
33
33
|
do_polysemous_training = false;
|
34
34
|
polysemous_ht = nbits * M + 1;
|
35
35
|
search_type = ST_PQ;
|
36
36
|
encode_signs = false;
|
37
|
+
code_size = pq.code_size;
|
37
38
|
}
|
38
39
|
|
39
40
|
IndexPQ::IndexPQ() {
|
@@ -69,53 +70,6 @@ void IndexPQ::train(idx_t n, const float* x) {
|
|
69
70
|
is_trained = true;
|
70
71
|
}
|
71
72
|
|
72
|
-
void IndexPQ::add(idx_t n, const float* x) {
|
73
|
-
FAISS_THROW_IF_NOT(is_trained);
|
74
|
-
codes.resize((n + ntotal) * pq.code_size);
|
75
|
-
pq.compute_codes(x, &codes[ntotal * pq.code_size], n);
|
76
|
-
ntotal += n;
|
77
|
-
}
|
78
|
-
|
79
|
-
size_t IndexPQ::remove_ids(const IDSelector& sel) {
|
80
|
-
idx_t j = 0;
|
81
|
-
for (idx_t i = 0; i < ntotal; i++) {
|
82
|
-
if (sel.is_member(i)) {
|
83
|
-
// should be removed
|
84
|
-
} else {
|
85
|
-
if (i > j) {
|
86
|
-
memmove(&codes[pq.code_size * j],
|
87
|
-
&codes[pq.code_size * i],
|
88
|
-
pq.code_size);
|
89
|
-
}
|
90
|
-
j++;
|
91
|
-
}
|
92
|
-
}
|
93
|
-
size_t nremove = ntotal - j;
|
94
|
-
if (nremove > 0) {
|
95
|
-
ntotal = j;
|
96
|
-
codes.resize(ntotal * pq.code_size);
|
97
|
-
}
|
98
|
-
return nremove;
|
99
|
-
}
|
100
|
-
|
101
|
-
void IndexPQ::reset() {
|
102
|
-
codes.clear();
|
103
|
-
ntotal = 0;
|
104
|
-
}
|
105
|
-
|
106
|
-
void IndexPQ::reconstruct_n(idx_t i0, idx_t ni, float* recons) const {
|
107
|
-
FAISS_THROW_IF_NOT(ni == 0 || (i0 >= 0 && i0 + ni <= ntotal));
|
108
|
-
for (idx_t i = 0; i < ni; i++) {
|
109
|
-
const uint8_t* code = &codes[(i0 + i) * pq.code_size];
|
110
|
-
pq.decode(code, recons + i * d);
|
111
|
-
}
|
112
|
-
}
|
113
|
-
|
114
|
-
void IndexPQ::reconstruct(idx_t key, float* recons) const {
|
115
|
-
FAISS_THROW_IF_NOT(key >= 0 && key < ntotal);
|
116
|
-
pq.decode(&codes[key * pq.code_size], recons);
|
117
|
-
}
|
118
|
-
|
119
73
|
namespace {
|
120
74
|
|
121
75
|
template <class PQDecoder>
|
@@ -457,9 +411,6 @@ void IndexPQ::search_core_polysemous(
|
|
457
411
|
}
|
458
412
|
|
459
413
|
/* The standalone codec interface (just remaps to the PQ functions) */
|
460
|
-
size_t IndexPQ::sa_code_size() const {
|
461
|
-
return pq.code_size;
|
462
|
-
}
|
463
414
|
|
464
415
|
void IndexPQ::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
465
416
|
pq.compute_codes(x, bytes, n);
|
@@ -12,7 +12,7 @@
|
|
12
12
|
|
13
13
|
#include <vector>
|
14
14
|
|
15
|
-
#include <faiss/
|
15
|
+
#include <faiss/IndexFlatCodes.h>
|
16
16
|
#include <faiss/impl/PolysemousTraining.h>
|
17
17
|
#include <faiss/impl/ProductQuantizer.h>
|
18
18
|
#include <faiss/impl/platform_macros.h>
|
@@ -21,13 +21,10 @@ namespace faiss {
|
|
21
21
|
|
22
22
|
/** Index based on a product quantizer. Stored vectors are
|
23
23
|
* approximated by PQ codes. */
|
24
|
-
struct IndexPQ :
|
24
|
+
struct IndexPQ : IndexFlatCodes {
|
25
25
|
/// The product quantizer used to encode the vectors
|
26
26
|
ProductQuantizer pq;
|
27
27
|
|
28
|
-
/// Codes. Size ntotal * pq.code_size
|
29
|
-
std::vector<uint8_t> codes;
|
30
|
-
|
31
28
|
/** Constructor.
|
32
29
|
*
|
33
30
|
* @param d dimensionality of the input vectors
|
@@ -43,8 +40,6 @@ struct IndexPQ : Index {
|
|
43
40
|
|
44
41
|
void train(idx_t n, const float* x) override;
|
45
42
|
|
46
|
-
void add(idx_t n, const float* x) override;
|
47
|
-
|
48
43
|
void search(
|
49
44
|
idx_t n,
|
50
45
|
const float* x,
|
@@ -52,17 +47,7 @@ struct IndexPQ : Index {
|
|
52
47
|
float* distances,
|
53
48
|
idx_t* labels) const override;
|
54
49
|
|
55
|
-
void reset() override;
|
56
|
-
|
57
|
-
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
|
58
|
-
|
59
|
-
void reconstruct(idx_t key, float* recons) const override;
|
60
|
-
|
61
|
-
size_t remove_ids(const IDSelector& sel) override;
|
62
|
-
|
63
50
|
/* The standalone codec interface */
|
64
|
-
size_t sa_code_size() const override;
|
65
|
-
|
66
51
|
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
67
52
|
|
68
53
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
@@ -155,6 +155,34 @@ void IndexRefine::reconstruct(idx_t key, float* recons) const {
|
|
155
155
|
refine_index->reconstruct(key, recons);
|
156
156
|
}
|
157
157
|
|
158
|
+
size_t IndexRefine::sa_code_size() const {
|
159
|
+
return base_index->sa_code_size() + refine_index->sa_code_size();
|
160
|
+
}
|
161
|
+
|
162
|
+
void IndexRefine::sa_encode(idx_t n, const float* x, uint8_t* bytes) const {
|
163
|
+
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
|
164
|
+
std::unique_ptr<uint8_t[]> tmp1(new uint8_t[n * cs1]);
|
165
|
+
base_index->sa_encode(n, x, tmp1.get());
|
166
|
+
std::unique_ptr<uint8_t[]> tmp2(new uint8_t[n * cs2]);
|
167
|
+
refine_index->sa_encode(n, x, tmp2.get());
|
168
|
+
for (size_t i = 0; i < n; i++) {
|
169
|
+
uint8_t* b = bytes + i * (cs1 + cs2);
|
170
|
+
memcpy(b, tmp1.get() + cs1 * i, cs1);
|
171
|
+
memcpy(b + cs1, tmp2.get() + cs2 * i, cs2);
|
172
|
+
}
|
173
|
+
}
|
174
|
+
|
175
|
+
void IndexRefine::sa_decode(idx_t n, const uint8_t* bytes, float* x) const {
|
176
|
+
size_t cs1 = base_index->sa_code_size(), cs2 = refine_index->sa_code_size();
|
177
|
+
std::unique_ptr<uint8_t[]> tmp2(
|
178
|
+
new uint8_t[n * refine_index->sa_code_size()]);
|
179
|
+
for (size_t i = 0; i < n; i++) {
|
180
|
+
memcpy(tmp2.get() + i * cs2, bytes + i * (cs1 + cs2), cs2);
|
181
|
+
}
|
182
|
+
|
183
|
+
refine_index->sa_decode(n, tmp2.get(), x);
|
184
|
+
}
|
185
|
+
|
158
186
|
IndexRefine::~IndexRefine() {
|
159
187
|
if (own_fields)
|
160
188
|
delete base_index;
|
@@ -49,6 +49,16 @@ struct IndexRefine : Index {
|
|
49
49
|
// reconstruct is routed to the refine_index
|
50
50
|
void reconstruct(idx_t key, float* recons) const override;
|
51
51
|
|
52
|
+
/* standalone codec interface: the base_index codes are interleaved with the
|
53
|
+
* refine_index ones */
|
54
|
+
size_t sa_code_size() const override;
|
55
|
+
|
56
|
+
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
57
|
+
|
58
|
+
/// The sa_decode decodes from the index_refine, which is assumed to be more
|
59
|
+
/// accurate
|
60
|
+
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|
61
|
+
|
52
62
|
~IndexRefine() override;
|
53
63
|
};
|
54
64
|
|
@@ -29,7 +29,7 @@ IndexScalarQuantizer::IndexScalarQuantizer(
|
|
29
29
|
int d,
|
30
30
|
ScalarQuantizer::QuantizerType qtype,
|
31
31
|
MetricType metric)
|
32
|
-
:
|
32
|
+
: IndexFlatCodes(0, d, metric), sq(d, qtype) {
|
33
33
|
is_trained = qtype == ScalarQuantizer::QT_fp16 ||
|
34
34
|
qtype == ScalarQuantizer::QT_8bit_direct;
|
35
35
|
code_size = sq.code_size;
|
@@ -43,13 +43,6 @@ void IndexScalarQuantizer::train(idx_t n, const float* x) {
|
|
43
43
|
is_trained = true;
|
44
44
|
}
|
45
45
|
|
46
|
-
void IndexScalarQuantizer::add(idx_t n, const float* x) {
|
47
|
-
FAISS_THROW_IF_NOT(is_trained);
|
48
|
-
codes.resize((n + ntotal) * code_size);
|
49
|
-
sq.compute_codes(x, &codes[ntotal * code_size], n);
|
50
|
-
ntotal += n;
|
51
|
-
}
|
52
|
-
|
53
46
|
void IndexScalarQuantizer::search(
|
54
47
|
idx_t n,
|
55
48
|
const float* x,
|
@@ -67,6 +60,7 @@ void IndexScalarQuantizer::search(
|
|
67
60
|
InvertedListScanner* scanner =
|
68
61
|
sq.select_InvertedListScanner(metric_type, nullptr, true);
|
69
62
|
ScopeDeleter1<InvertedListScanner> del(scanner);
|
63
|
+
scanner->list_no = 0; // directly the list number
|
70
64
|
|
71
65
|
#pragma omp for
|
72
66
|
for (idx_t i = 0; i < n; i++) {
|
@@ -99,27 +93,7 @@ DistanceComputer* IndexScalarQuantizer::get_distance_computer() const {
|
|
99
93
|
return dc;
|
100
94
|
}
|
101
95
|
|
102
|
-
void IndexScalarQuantizer::reset() {
|
103
|
-
codes.clear();
|
104
|
-
ntotal = 0;
|
105
|
-
}
|
106
|
-
|
107
|
-
void IndexScalarQuantizer::reconstruct_n(idx_t i0, idx_t ni, float* recons)
|
108
|
-
const {
|
109
|
-
std::unique_ptr<ScalarQuantizer::Quantizer> squant(sq.select_quantizer());
|
110
|
-
for (size_t i = 0; i < ni; i++) {
|
111
|
-
squant->decode_vector(&codes[(i + i0) * code_size], recons + i * d);
|
112
|
-
}
|
113
|
-
}
|
114
|
-
|
115
|
-
void IndexScalarQuantizer::reconstruct(idx_t key, float* recons) const {
|
116
|
-
reconstruct_n(key, 1, recons);
|
117
|
-
}
|
118
|
-
|
119
96
|
/* Codec interface */
|
120
|
-
size_t IndexScalarQuantizer::sa_code_size() const {
|
121
|
-
return sq.code_size;
|
122
|
-
}
|
123
97
|
|
124
98
|
void IndexScalarQuantizer::sa_encode(idx_t n, const float* x, uint8_t* bytes)
|
125
99
|
const {
|
@@ -13,6 +13,7 @@
|
|
13
13
|
#include <stdint.h>
|
14
14
|
#include <vector>
|
15
15
|
|
16
|
+
#include <faiss/IndexFlatCodes.h>
|
16
17
|
#include <faiss/IndexIVF.h>
|
17
18
|
#include <faiss/impl/ScalarQuantizer.h>
|
18
19
|
|
@@ -24,15 +25,10 @@ namespace faiss {
|
|
24
25
|
* (default).
|
25
26
|
*/
|
26
27
|
|
27
|
-
struct IndexScalarQuantizer :
|
28
|
+
struct IndexScalarQuantizer : IndexFlatCodes {
|
28
29
|
/// Used to encode the vectors
|
29
30
|
ScalarQuantizer sq;
|
30
31
|
|
31
|
-
/// Codes. Size ntotal * pq.code_size
|
32
|
-
std::vector<uint8_t> codes;
|
33
|
-
|
34
|
-
size_t code_size;
|
35
|
-
|
36
32
|
/** Constructor.
|
37
33
|
*
|
38
34
|
* @param d dimensionality of the input vectors
|
@@ -48,8 +44,6 @@ struct IndexScalarQuantizer : Index {
|
|
48
44
|
|
49
45
|
void train(idx_t n, const float* x) override;
|
50
46
|
|
51
|
-
void add(idx_t n, const float* x) override;
|
52
|
-
|
53
47
|
void search(
|
54
48
|
idx_t n,
|
55
49
|
const float* x,
|
@@ -57,17 +51,9 @@ struct IndexScalarQuantizer : Index {
|
|
57
51
|
float* distances,
|
58
52
|
idx_t* labels) const override;
|
59
53
|
|
60
|
-
void reset() override;
|
61
|
-
|
62
|
-
void reconstruct_n(idx_t i0, idx_t ni, float* recons) const override;
|
63
|
-
|
64
|
-
void reconstruct(idx_t key, float* recons) const override;
|
65
|
-
|
66
54
|
DistanceComputer* get_distance_computer() const override;
|
67
55
|
|
68
56
|
/* standalone codec interface */
|
69
|
-
size_t sa_code_size() const override;
|
70
|
-
|
71
57
|
void sa_encode(idx_t n, const float* x, uint8_t* bytes) const override;
|
72
58
|
|
73
59
|
void sa_decode(idx_t n, const uint8_t* bytes, float* x) const override;
|